Advertisement
Guest User

Untitled

a guest
Jul 24th, 2017
43
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.59 KB | None | 0 0
  1. #!/usr/bin/env python
  2. from mpi4py import MPI
  3. from baselines.common import set_global_seeds
  4. import os.path as osp
  5. import gym, logging
  6. from baselines import logger
  7. from baselines.pposgd.mlp_policy import MlpPolicy
  8. from baselines.common.mpi_fork import mpi_fork
  9. from baselines import bench
  10. from baselines.trpo_mpi import trpo_mpi
  11. import sys
  12. num_cpu=1
  13.  
  14. from gym import wrappers
  15.  
  16. def train(env_id, num_timesteps, seed):
  17. whoami = mpi_fork(num_cpu)
  18. if whoami == "parent":
  19. return
  20. import baselines.common.tf_util as U
  21. logger.session().__enter__()
  22. sess = U.single_threaded_session()
  23. sess.__enter__()
  24.  
  25. rank = MPI.COMM_WORLD.Get_rank()
  26. if rank != 0:
  27. logger.set_level(logger.DISABLED)
  28. workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
  29. set_global_seeds(workerseed)
  30. env = gym.make(env_id)
  31. env = wrappers.Monitor(env, directory='gym_mon_trpo', force=True)
  32. def policy_fn(name, ob_space, ac_space):
  33. return MlpPolicy(name=name, ob_space=env.observation_space, ac_space=env.action_space,
  34. hid_size=64, num_hid_layers=2)
  35. #env = bench.Monitor(env, osp.join(logger.get_dir(), "%i.monitor.json" % rank))
  36. #env = bench.Monitor(env, osp.join('./', "trpo_monitor.json"))
  37. env.seed(workerseed)
  38. gym.logger.setLevel(logging.WARN)
  39.  
  40. trpo_mpi.learn(env, policy_fn, timesteps_per_batch=15000, max_kl=0.01, cg_iters=500, cg_damping=0.1,
  41. max_timesteps=num_timesteps, gamma=0.995, lam=0.97, vf_iters=5, vf_stepsize=1e-3)
  42. env.close()
  43.  
  44.  
  45. def main():
  46. train('InvertedPendulum-v1', num_timesteps=1e6, seed=0)
  47.  
  48.  
  49. if __name__ == '__main__':
  50. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement