Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#!/usr/bin/env python
"""Launch TRPO (baselines.trpo_mpi) with an MLP policy on a Gym environment."""
import logging
import os.path as osp
import sys

from mpi4py import MPI
from baselines.common import set_global_seeds
import gym
from baselines import logger
from baselines.pposgd.mlp_policy import MlpPolicy
from baselines.common.mpi_fork import mpi_fork
from baselines import bench
from baselines.trpo_mpi import trpo_mpi
from gym import wrappers

# NOTE(review): osp, bench and sys appear unused by the live code in this script.

# Passed to mpi_fork() by train(); presumably the number of MPI worker
# processes to launch — confirm against baselines.common.mpi_fork.
num_cpu = 1
def train(env_id, num_timesteps, seed):
    """Train a TRPO agent with a 2x64 MLP policy on a Gym environment under MPI.

    Args:
        env_id: Gym environment id passed to ``gym.make``.
        num_timesteps: Total environment steps, forwarded to
            ``trpo_mpi.learn`` as ``max_timesteps``.
        seed: Base random seed; each MPI worker seeds with
            ``seed + 10000 * rank`` so workers draw different streams.

    Returns:
        None. When ``mpi_fork`` reports ``"parent"`` the call returns
        immediately without training.

    Side effects:
        Records episodes with ``gym.wrappers.Monitor`` into ``gym_mon_trpo``
        (rank 0) or ``gym_mon_trpo_<rank>`` (other ranks). Existing contents
        are wiped because ``force=True``.
    """
    whoami = mpi_fork(num_cpu)
    if whoami == "parent":
        # The launching process does not train itself.
        return
    # Imported only after forking. NOTE(review): presumably this keeps the
    # parent from initializing TensorFlow; confirm intent.
    import baselines.common.tf_util as U

    # Use context managers so the logger session and the TF session are
    # closed on exit. The original entered them and never exited.
    with logger.session(), U.single_threaded_session():
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            # Only rank 0 produces log output.
            logger.set_level(logger.DISABLED)
        workerseed = seed + 10000 * rank
        set_global_seeds(workerseed)

        env = gym.make(env_id)
        # Give each rank its own directory. force=True clears the target
        # directory, so ranks sharing one would delete each other's files.
        # Rank 0 keeps the original directory name.
        monitor_dir = 'gym_mon_trpo' if rank == 0 else 'gym_mon_trpo_%i' % rank
        env = wrappers.Monitor(env, directory=monitor_dir, force=True)
        try:
            def policy_fn(name, ob_space, ac_space):
                # Build the policy from the spaces the learner passes in,
                # rather than from the enclosing `env`.
                return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                                 hid_size=64, num_hid_layers=2)

            env.seed(workerseed)
            gym.logger.setLevel(logging.WARN)
            trpo_mpi.learn(env, policy_fn, timesteps_per_batch=15000, max_kl=0.01,
                           cg_iters=500, cg_damping=0.1,
                           max_timesteps=num_timesteps, gamma=0.995, lam=0.97,
                           vf_iters=5, vf_stepsize=1e-3)
        finally:
            # Always close the env so the Monitor flushes its recordings,
            # even when training raises.
            env.close()
def main():
    """Script entry point: run train() on InvertedPendulum-v1 for 1e6 steps, seed 0."""
    train(env_id='InvertedPendulum-v1', seed=0, num_timesteps=1e6)


if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement