import numpy as np
import matplotlib.pyplot as plt
import copy


class EpsGreedyQPolicy:
    def __init__(self, epsilon=.1, decay_rate=1):
        self.epsilon = epsilon
        self.decay_rate = decay_rate

    def select_action(self, q_values):
        assert q_values.ndim == 1
        nb_actions = q_values.shape[0]

        if np.random.uniform() < self.epsilon:  # random action (exploration)
            action = np.random.randint(0, nb_actions)
        else:  # greedy action (exploitation)
            action = np.argmax(q_values)

        return action

    def decay_epsilon(self):
        # shrink epsilon by the decay rate so exploration decreases over time
        self.epsilon *= self.decay_rate
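

# A minimal sketch of how the epsilon-greedy policy above behaves.
# _demo_eps_greedy_policy is a hypothetical helper added for illustration and is
# not called by the script below; the q-values are made-up numbers. With
# epsilon=0.1 the greedy index (2 here) should be picked roughly 90% of the time,
# and a uniformly random index otherwise.
def _demo_eps_greedy_policy():
    demo_policy = EpsGreedyQPolicy(epsilon=0.1)
    demo_q_values = np.array([0.0, 0.5, 1.0])
    counts = np.zeros(len(demo_q_values), dtype=int)
    for _ in range(1000):
        counts[demo_policy.select_action(demo_q_values)] += 1
    return counts  # expect most of the mass on index 2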


class SARSAAgent:
    """
    SARSA agent.
    """
    def __init__(self, alpha=.2, policy=None, gamma=.99, actions=None,
                 observation=None, alpha_decay_rate=None):
        self.alpha = alpha
        self.gamma = gamma
        self.policy = policy
        self.reward_history = []
        self.actions = actions
        self.alpha_decay_rate = alpha_decay_rate
        self.state = str(observation)
        self.previous_state = None
        self.ini_state = str(observation)  # keep the initial state for restarts
        self.previous_action_id = None
        self.recent_action_id = 0
        self.q_values = self._init_q_values()
        self.training = True

    def _init_q_values(self):
        """
        Initialize the Q table.
        """
        q_values = {}
        q_values[self.state] = np.repeat(0.0, len(self.actions))
        return q_values

    def init_state(self):
        """
        Reset the agent to its initial state (used when restarting an episode).
        """
        self.previous_state = None
        self.state = copy.deepcopy(self.ini_state)
        return self.state

    def init_policy(self, policy):
        self.policy = policy

    def act(self):
        action = self.actions[self.recent_action_id]
        return action

    def select_action(self):
        action_id = self.policy.select_action(self.q_values[self.state])
        return action_id

    def observe(self, next_state, reward=None):
        """
        Observe the next state and, if a reward is given, learn from it.
        """
        next_state = str(next_state)
        if next_state not in self.q_values:  # first visit to this state
            self.q_values[next_state] = np.repeat(0.0, len(self.actions))

        self.previous_state = copy.deepcopy(self.state)
        self.state = next_state

        if self.training and reward is not None:
            self.reward_history.append(reward)
            self.learn(reward)

    def learn(self, reward):
        """
        Select the next action and update the Q value with the received reward.
        """
        self.previous_action_id = copy.deepcopy(self.recent_action_id)
        self.recent_action_id = self.select_action()
        self.q_values[self.previous_state][self.previous_action_id] = self._update_q_value(reward)

    def _update_q_value(self, reward):
        """
        Compute the updated Q value.
        """
        q = self.q_values[self.previous_state][self.previous_action_id]  # Q(s, a)
        q2 = self.q_values[self.state][self.recent_action_id]            # Q(s', a')

        # Q(s, a) <- Q(s, a) + alpha * (r + gamma * Q(s', a') - Q(s, a))
        updated_q_value = q + (self.alpha * (reward + (self.gamma * q2) - q))

        return updated_q_value

    def decay_alpha(self):
        # shrink the learning rate if a decay rate was given
        if self.alpha_decay_rate is not None:
            self.alpha *= self.alpha_decay_rate

    def update_hyper_parameters(self):
        """
        Decay the hyper-parameters (learning rate and exploration rate).
        """
        self.decay_alpha()
        self.policy.decay_epsilon()
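

# Worked example of the SARSA update computed in _update_q_value above, with
# made-up numbers: alpha = 0.2, gamma = 0.99, Q(s, a) = 1.0, r = 0, Q(s', a') = 2.0
# gives 1.0 + 0.2 * (0 + 0.99 * 2.0 - 1.0) = 1.196. Unlike Q-learning, the
# bootstrap term uses the action actually selected in s', not max_a Q(s', a).
# _demo_sarsa_update is a hypothetical helper for illustration only; it is not
# called by the script below.
def _demo_sarsa_update():
    alpha, gamma = 0.2, 0.99
    q, q2, reward = 1.0, 2.0, 0.0
    updated = q + alpha * (reward + gamma * q2 - q)
    return updated  # 1.196 (up to floating-point rounding)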


class GridWorld:

    def __init__(self):

        self.map = [[0, 2, 0, 1],
                    [0, 0, 0, 2],
                    [0, 0, 2, 0],
                    [0, 2, 0, 0],
                    [0, 0, 0, 0]]

        self.start_pos = 0, 4  # the agent's start position (x, y)
        self.agent_pos = copy.deepcopy(self.start_pos)  # the agent's current position

        self.field_type = {
            "N": 0,  # normal cell
            "G": 1,  # goal
            "W": 2,  # wall
        }

        self.actions = {
            "UP": 0,
            "DOWN": 1,
            "LEFT": 2,
            "RIGHT": 3
        }

    def step(self, action):
        """
        Execute an action.
        Return the next state, the reward and whether the goal was reached.
        """
        to_x, to_y = copy.deepcopy(self.agent_pos)

        # Check whether the move is possible. If it is not, the agent stays in
        # place and receives a negative reward.
        if not self._is_possible_action(to_x, to_y, action):
            return self.agent_pos, -1, False

        if action == self.actions["UP"]:
            to_y += -1
        elif action == self.actions["DOWN"]:
            to_y += 1
        elif action == self.actions["LEFT"]:
            to_x += -1
        elif action == self.actions["RIGHT"]:
            to_x += 1

        is_goal = self._is_goal(to_x, to_y)  # check whether the goal was reached
        reward = self._compute_reward(to_x, to_y)
        self.agent_pos = to_x, to_y
        return self.agent_pos, reward, is_goal

    def _is_goal(self, x, y):
        """
        Check whether (x, y) is the goal cell.
        """
        if self.map[y][x] == self.field_type["G"]:
            return True
        else:
            return False

    def _is_wall(self, x, y):
        """
        Check whether (x, y) is a wall.
        """
        if self.map[y][x] == self.field_type["W"]:
            return True
        else:
            return False

    def _is_possible_action(self, x, y, action):
        """
        Check whether the action can be executed from (x, y).
        """
        to_x = x
        to_y = y

        if action == self.actions["UP"]:
            to_y += -1
        elif action == self.actions["DOWN"]:
            to_y += 1
        elif action == self.actions["LEFT"]:
            to_x += -1
        elif action == self.actions["RIGHT"]:
            to_x += 1

        if len(self.map) <= to_y or 0 > to_y:
            return False
        elif len(self.map[0]) <= to_x or 0 > to_x:
            return False
        elif self._is_wall(to_x, to_y):
            return False

        return True

    def _compute_reward(self, x, y):
        if self.map[y][x] == self.field_type["N"]:
            return 0
        elif self.map[y][x] == self.field_type["G"]:
            return 100

    def reset(self):
        self.agent_pos = self.start_pos
        return self.start_pos
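

# A minimal sketch of a single environment step. _demo_grid_step is a
# hypothetical helper for illustration only and is not called by the script
# below. From the start position (0, 4), moving UP lands on the normal cell
# (0, 3), so the reward is 0 and the episode does not end; moving LEFT is
# blocked by the grid border, so the agent stays put and receives -1.
def _demo_grid_step():
    env = GridWorld()
    pos, reward, done = env.step(env.actions["UP"])    # ((0, 3), 0, False)
    env.reset()
    pos, reward, done = env.step(env.actions["LEFT"])  # ((0, 4), -1, False)
    return pos, reward, done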


if __name__ == '__main__':
    grid_env = GridWorld()  # initialize the grid world environment
    ini_state = grid_env.start_pos  # initial state (the agent's start position)
    policy = EpsGreedyQPolicy(epsilon=.01, decay_rate=.99)  # initialize the policy (epsilon-greedy here)
    agent = SARSAAgent(policy=policy, actions=np.arange(4), observation=ini_state)  # initialize the SARSA agent
    nb_episode = 100  # number of episodes
    rewards = []  # rewards kept for evaluation
    is_goal = False  # whether the agent has reached the goal
    for episode in range(nb_episode):
        episode_reward = []  # rewards collected during this episode
        while not is_goal:  # keep acting until the goal is reached
            action = agent.act()  # select an action
            state, reward, is_goal = grid_env.step(action)
            agent.observe(state, reward)  # observe the next state and the reward
            episode_reward.append(reward)
        rewards.append(np.sum(episode_reward))  # store the cumulative reward of this episode
        state = grid_env.reset()  # reset the environment
        agent.observe(state)  # move the agent back to its start position
        is_goal = False

    # plot the results
    plt.plot(np.arange(nb_episode), rewards)
    plt.xlabel("episode")
    plt.ylabel("reward")
    plt.savefig("result.jpg")
    plt.show()