import argparse
import gym
import six
#import numpy as np
import random

import chainer
from chainer import functions as F
from chainer import cuda
from chainer import links as L

gpu_device = 0
cuda.get_device(gpu_device).use()
xp = chainer.cuda.cupy

class Agent(chainer.Chain):
    gamma = 0.99
    initial_epsilon = 1
    epsilon_reduction = 0.001
    min_epsilon = 0.01

    def __init__(self, input_size, output_size, hidden_size):
        initialW = chainer.initializers.HeNormal(0.01)
        super(Agent, self).__init__(
            fc1=L.Linear(input_size, hidden_size, initialW=initialW),
            fc2=L.Linear(hidden_size, hidden_size, initialW=initialW),
            fc3=L.Linear(hidden_size, output_size, initialW=initialW),
        )
        self.epsilon = self.initial_epsilon
        self.output_size = output_size

    def __call__(self, x):
        # Three-layer MLP producing one Q-value per action.
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        h = self.fc3(h)
        return h

    def randomize_action(self, action):
        # Epsilon-greedy exploration: with probability epsilon, replace the
        # greedy action with a uniformly random one.
        if random.random() < self.epsilon:
            return random.randint(0, self.output_size - 1)
        return action

    def reduce_epsilon(self):
        # Decay epsilon exponentially toward min_epsilon.
        self.epsilon = (self.epsilon - self.min_epsilon) * (1 - self.epsilon_reduction) + self.min_epsilon

    def adjust_reward(self, state, reward, done):
        return reward

    def normalize_state(self, state):
        return xp.asarray(state, dtype=xp.float32)

class CartPoleAgent(Agent):
    gamma = 0.9
    initial_epsilon = 1
    min_epsilon = 0.01
    epsilon_reduction = 0.001

    def __init__(self):
        super(CartPoleAgent, self).__init__(4, 2, 24)

    def adjust_reward(self, state, reward, done):
        return reward

    def normalize_state(self, state):
        scale = xp.asarray([1 / 2.4, 1 / 4.0, 1 / 0.2, 1 / 3.0], dtype=xp.float32)
        return xp.asarray(state, dtype=xp.float32) * scale

class MountainCarAgent(Agent):
    gamma = 0.99
    initial_epsilon = 0.8
    min_epsilon = 0.1
    epsilon_reduction = 0.0001

    def __init__(self):
        super(MountainCarAgent, self).__init__(2, 3, 64)

    def adjust_reward(self, state, reward, done):
        return reward

    def normalize_state(self, state):
        scale = xp.asarray([1 / 1.2, 1 / 0.07], dtype=xp.float32)
        return xp.asarray(state, dtype=xp.float32) * scale

class ExperiencePool(object):
    """Fixed-size ring buffer of (state, action, reward, episode-continues) entries."""

    def __init__(self, size, state_shape):
        self.size = size
        self.states = xp.zeros(((size,) + state_shape), dtype=xp.float32)
        self.actions = xp.zeros((size,), dtype=xp.int32)
        self.rewards = xp.zeros((size,), dtype=xp.float32)
        self.nexts = xp.zeros((size,), dtype=xp.float32)
        self.pos = 0

    def add(self, state, action, reward, done):
        index = self.pos % self.size
        self.states[index, ...] = state
        self.actions[index] = action
        self.rewards[index] = reward
        # nexts[index] is 0 when the episode ended at this step, 1 otherwise.
        if done:
            self.nexts[index] = 0
        else:
            self.nexts[index] = 1
        self.pos += 1

    def available_size(self):
        # One slot is held back so that index + 1 (the next state) is always valid.
        if self.pos > self.size:
            return self.size - 1
        return self.pos - 1

    def __getitem__(self, index):
        # Map the logical index onto the ring buffer; the entry stored right
        # after it holds the next state of the transition.
        if self.pos < self.size:
            offset = 0
        else:
            offset = self.pos % self.size - self.size
        index += offset
        return self.states[index], self.actions[index], self.rewards[index], self.states[index + 1], self.nexts[index]

def update(agent, target_agent, optimizer, ex_pool, batch_size):
    available_size = ex_pool.available_size()
    if available_size < batch_size:
        return
    # Sample a random minibatch of transitions from the experience pool.
    indices = xp.random.permutation(available_size)[:batch_size]
    data = [ex_pool[i] for i in indices]
    state, action, reward, next_state, has_next = zip(*data)
    state = xp.asarray(state)
    action = xp.asarray(action)
    reward = xp.asarray(reward)
    next_state = xp.asarray(next_state)
    has_next = xp.asarray(has_next)

    q = F.select_item(agent(state), action)
    # Double Q-learning target: the online network selects the next action,
    # the target network evaluates it; has_next zeroes the bootstrap term at
    # episode boundaries.
    next_action = xp.argmax(agent(next_state).data, axis=1)
    y = reward + agent.gamma * has_next * target_agent(next_state).data[xp.arange(len(next_action)), next_action]
    loss = F.mean_squared_error(q, y)
    agent.cleargrads()
    loss.backward()
    optimizer.update()

def parse_arg():
    parser = argparse.ArgumentParser(description='OpenAI Gym learning sample')
    parser.add_argument('--env', '-e', type=str, choices=['cart_pole', 'mountain_car'], help='Environment name')
    parser.add_argument('--skip_render', '-s', type=int, default=0, help='Interval (in episodes) between rendered episodes; 0 renders every episode')
    parser.add_argument('--batch-size', '-b', type=int, default=32, help='Batch size for training')
    parser.add_argument('--pool-size', '-p', type=int, default=2000, help='Experience pool size')
    parser.add_argument('--train-iter', '-t', type=int, default=10, help='Number of training iterations per step')
    parser.add_argument('--episode', type=int, default=1000, help='Number of episodes')
    parser.add_argument('--episode-len', type=int, default=1000, help='Length of an episode')
    parser.add_argument('--use-double-q', action='store_true', help='Use Double Q-learning')
    return parser.parse_args()

def main():
    args = parse_arg()
    episode_num = args.episode
    episode_length = args.episode_len
    pool_size = args.pool_size
    batch_size = args.batch_size
    train_num = args.train_iter
    update_count = 0
    update_agent_interval = 100
    use_double_q = args.use_double_q

    env_name = args.env
    if env_name == 'mountain_car':
        env = gym.make('MountainCar-v0')
        agent = MountainCarAgent()
    else:
        env = gym.make('CartPole-v0')
        agent = CartPoleAgent()
    agent.to_gpu(gpu_device)
    skip_rendering_interval = args.skip_render

    # With Double Q-learning the target network is a periodically refreshed
    # copy of the online network; otherwise both names refer to the same network.
    if use_double_q:
        target_agent = agent.copy()
    else:
        target_agent = agent
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(agent)
    ex_pool = ExperiencePool(pool_size, env.observation_space.shape)

    for episode in six.moves.range(episode_num):
        raw_state = env.reset()
        state = agent.normalize_state(raw_state)
        need_render = skip_rendering_interval <= 0 or episode % skip_rendering_interval == 0
        for t in six.moves.range(episode_length):
            if need_render:
                env.render()
            # Greedy action from the Q-network, possibly replaced by a random
            # action for exploration.
            action = int(xp.argmax(agent(xp.expand_dims(state, 0)).data))
            action = agent.randomize_action(action)

            prev_state = state
            raw_state, raw_reward, done, info = env.step(action)
            reward = agent.adjust_reward(raw_state, raw_reward, done)
            state = agent.normalize_state(raw_state)
            ex_pool.add(prev_state, action, reward, done or t == episode_length - 1)
            for i in six.moves.range(train_num):
                update(agent, target_agent, optimizer, ex_pool, batch_size)
            update_count += 1
            agent.reduce_epsilon()
            if use_double_q and update_count % update_agent_interval == 0:
                target_agent = agent.copy()
            if done:
                print('Episode {} finished after {} timesteps'.format(episode + 1, t + 1))
                break
        if not done:
            print('Episode {} completed'.format(episode + 1))

if __name__ == '__main__':
    main()
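# Example invocations (a sketch, assuming the script is saved as dqn_gym.py; the
# filename is arbitrary, and the flags below come from parse_arg() above):
#
#   python dqn_gym.py --env cart_pole --batch-size 32 --use-double-q
#   python dqn_gym.py --env mountain_car --skip_render 10 --episode 2000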