import argparse
import gym
import numpy as np
import os
import tensorflow as tf
import time
import pickle
from tensorforce.agents import PPOAgent
from rl_models.agents import *

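# NOTE: rl_models.agents is not part of this paste; it is expected to provide the
# DistributedTrainer class used below, presumably built on the tensorforce PPOAgent
# imported above. gym and os are imported here but not used directly in this script.
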
def parse_args():
    parser = argparse.ArgumentParser("Reinforcement Learning experiments for multiagent environments")
    # Environment
    parser.add_argument("--scenario", type=str, default="vip_rl", help="name of the scenario script")
    parser.add_argument("--max-episode-len", type=int, default=200, help="maximum episode length")
    parser.add_argument("--num-episodes", type=int, default=60000, help="number of episodes")
    parser.add_argument("--agent-type", type=str, default="ppo", help="policy for bodyguards")

    # Checkpointing
    parser.add_argument("--exp-name", type=str, default=None, help="name of the experiment")
    parser.add_argument("--save-dir", type=str, default="/tmp/policy/", help="directory in which training state and model should be saved")
    parser.add_argument("--save-rate", type=int, default=1000, help="save model once every time this many episodes are completed")
    parser.add_argument("--load-dir", type=str, default="", help="directory in which training state and model are loaded")
    # Evaluation
    parser.add_argument("--restore", action="store_true", default=False)
    parser.add_argument("--display", action="store_true", default=False)
    parser.add_argument("--benchmark", action="store_true", default=False)
    parser.add_argument("--benchmark-iters", type=int, default=100000, help="number of iterations run for benchmarking")
    parser.add_argument("--benchmark-dir", type=str, default="./benchmark_files/", help="directory where benchmark data is saved")
    parser.add_argument("--plots-dir", type=str, default="./learning_curves/", help="directory where plot data is saved")
    return parser.parse_args()

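# Of the flags above, this script only reads --scenario, --max-episode-len,
# --num-episodes, --agent-type, --exp-name, --save-rate, --display and --plots-dir;
# the remaining checkpointing and benchmarking options are parsed but never used here.
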
def make_env(scenario_name):
    from multiagent.environment import MultiAgentEnv
    import multiagent.scenarios as scenarios

    # load scenario from script
    scenario = scenarios.load(scenario_name + ".py").Scenario()
    # create world
    world = scenario.make_world()
    # create multiagent environment
    return MultiAgentEnv(world, scenario.reset_world, scenario.reward, scenario.observation, scenario.info, scenario.done)

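# make_env follows the callback-based MultiAgentEnv constructor of OpenAI's
# multiagent-particle-envs; the default "vip_rl" scenario is a custom scenario
# script (not included in this paste) that must be discoverable by
# multiagent.scenarios.load().
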
def get_trainers(rl_agent, observation_space_dimension, action_space_dimension, number_of_agents):
    # build one trainer per agent, configured from "<agent-type>.json" (e.g. ppo.json)
    trainers = []
    for i in range(number_of_agents):
        agent = DistributedTrainer(rl_agent + ".json", observation_space_dimension[i], action_space_dimension[i])
        trainers.append(agent)
    return trainers

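# DistributedTrainer itself is not shown in this paste. The training loop below only
# relies on two of its methods, action(obs) and update(reward, terminal); a rough,
# hypothetical stand-in with that interface (random policy, no learning) could look like:
#
#   class DistributedTrainer:
#       def __init__(self, config_file, observation_space, action_space):
#           # the real class presumably builds an RL agent (e.g. tensorforce PPO)
#           # from the JSON config; this stub just remembers the spaces
#           self.observation_space = observation_space
#           self.action_space = action_space
#
#       def action(self, obs):
#           return self.action_space.sample()  # placeholder: random action
#
#       def update(self, reward, terminal):
#           pass  # a real trainer would learn from the reward/terminal signal here
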
def train(arglist):
    # Create environment and one trainer per agent
    env = make_env(arglist.scenario)
    obs_shape_n = [env.observation_space[i] for i in range(env.n)]
    action_shape_n = [env.action_space[i] for i in range(env.n)]
    trainers = get_trainers(arglist.agent_type, obs_shape_n, action_shape_n, env.n)

    episode_rewards = [0.0]  # sum of rewards for all agents
    agent_rewards = [[0.0] for _ in range(env.n)]  # individual agent reward
    final_ep_rewards = []  # sum of rewards for training curve
    final_ep_ag_rewards = []  # agent rewards for training curve
    agent_info = [[[]]]  # placeholder for benchmarking info
    obs_n = env.reset()
    episode_step = 0
    train_step = 0

    t_start = time.time()
    print('Starting iterations...')
    while True:
        # each agent chooses an action from its current observation
        action_n = [agent.action(obs) for agent, obs in zip(trainers, obs_n)]
        # environment step
        new_obs_n, rew_n, done_n, info_n = env.step(action_n)
        episode_step += 1
        train_step += 1
        done = all(done_n)
        terminal = (episode_step >= arglist.max_episode_len)
        obs_n = new_obs_n

        for i, rew in enumerate(rew_n):
            episode_rewards[-1] += rew
            agent_rewards[i][-1] += rew

        if done or terminal:
            # give each agent its final reward, then reset for the next episode
            for i, agent in enumerate(trainers):
                agent.update(rew_n[i], done or terminal)
            obs_n = env.reset()
            # NOTE: without a tf.summary.FileWriter (and a session to evaluate it),
            # this only adds an op to the default graph; nothing is written to disk
            tf.summary.scalar('episode_reward', tf.constant(episode_rewards[-1]))
            episode_step = 0
            episode_rewards.append(0)
            for a in agent_rewards:
                a.append(0)
            agent_info.append([[]])

        # for displaying learned policies (skips the logging and saving code below)
        if arglist.display:
            env.render()
            continue

        if terminal and (len(episode_rewards) % arglist.save_rate == 0):
            print("steps: {}, episodes: {}, mean episode reward: {}, agent episode reward: {}, time: {}".format(
                train_step, len(episode_rewards), np.mean(episode_rewards[-arglist.save_rate:]),
                [np.mean(rew[-arglist.save_rate:]) for rew in agent_rewards], round(time.time() - t_start, 3)))
            t_start = time.time()
            # Keep track of final episode reward
            final_ep_rewards.append(np.mean(episode_rewards[-arglist.save_rate:]))
            for rew in agent_rewards:
                final_ep_ag_rewards.append(np.mean(rew[-arglist.save_rate:]))

        if len(episode_rewards) > arglist.num_episodes:
            # dump the learning curves and stop
            rew_file_name = arglist.plots_dir + arglist.exp_name + '_rewards.pkl'
            with open(rew_file_name, 'wb') as fp:
                pickle.dump(final_ep_rewards, fp)
            agrew_file_name = arglist.plots_dir + arglist.exp_name + '_agrewards.pkl'
            with open(agrew_file_name, 'wb') as fp:
                pickle.dump(final_ep_ag_rewards, fp)
            print('...Finished total of {} episodes.'.format(len(episode_rewards)))
            break


if __name__ == '__main__':
    arglist = parse_args()
    train(arglist)
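
# Example invocation (assumption: this paste is saved as a script, e.g. train.py,
# with a vip_rl.py scenario importable by multiagent.scenarios and a ppo.json
# agent config available; none of those files are part of this paste):
#
#   python train.py --scenario vip_rl --agent-type ppo --exp-name vip_test --num-episodes 10000
#
# --exp-name should always be given: with the default (None), the file-path
# concatenation at the end of train() would raise a TypeError.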