from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
#from pettingzoo.butterfly import pistonball_v4
import supersuit as ss
from ray import tune
from ray.tune.suggest.optuna import OptunaSearch
import optuna
import os
import ray
from pathlib import Path
import gym
from ray.tune.suggest import ConcurrencyLimiter

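# Optuna search space: only the PPO entropy coefficient is tuned, sampled log-uniformly in [0.001, 0.1]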
space = {
    "ent_coef": optuna.distributions.LogUniformDistribution(.001, .1),
}


optuna_search = OptunaSearch(
    space,
    metric="mean_reward",
    mode="max")


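# Build either a single env (n_envs is None, used for evaluation) or an n_envs-wide SB3 vector env for training.
# The commented-out lines are the original pistonball/PettingZoo preprocessing pipeline.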
def make_env(n_envs):
    if n_envs is None:
        #env = pistonball_v4.env(time_penalty=-1)
        env = gym.make('LunarLanderContinuous-v2')
    else:
        #env = pistonball_v4.parallel_env(time_penalty=-1)
        env = gym.make('LunarLanderContinuous-v2')
        env = ss.stable_baselines3_vec_env_v0(env, n_envs, multiprocessing=False)

    # env = ss.color_reduction_v0(env, mode='B')
    # env = ss.resize_v0(env, x_size=84, y_size=84)
    # env = ss.frame_stack_v1(env, 3)
    # if n_envs is not None:
    #     env = ss.pettingzoo_env_to_vec_env_v0(env)
    #     env = ss.concat_vec_envs_v0(env, 2*n_envs, num_cpus=4, base_class='stable_baselines')

    return env


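# Load the latest checkpoint saved for this trial and return its average return over 100 evaluation episodes.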
def evaluate_all_policies(name):

    def evaluate_policy(env, model):
        total_reward = 0
        NUM_RESETS = 100
        # PettingZoo (AEC) evaluation loop, kept for reference:
        # for i in range(NUM_RESETS):
        #     env.reset()
        #     for agent in env.agent_iter():
        #         obs, reward, done, info = env.last()
        #         total_reward += reward
        #         act = model.predict(obs, deterministic=True)[0] if not done else None
        #         env.step(act)
        for i in range(NUM_RESETS):
            done = False
            obs = env.reset()
            while not done:
                act = model.predict(obs, deterministic=True)[0]
                obs, reward, done, info = env.step(act)
                total_reward += reward

        return total_reward/NUM_RESETS

    env = make_env(None)
    policy_folder = str(Path.home())+'/policy_logs/'+name+'/'
    policy_files = os.listdir(policy_folder)
    # checkpoints are named 'rl_model_<steps>_steps.zip'; pick the one with the most steps
    policy_file = sorted(policy_files, key=lambda x: int(x[9:-10]))[-1]
    model = PPO.load(policy_folder+policy_file)

    return evaluate_policy(env, model)


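# Build a filesystem-safe run name from the sampled hyperparameters (keys plus truncated values, dots stripped).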
def gen_filename(params):
    name = ''
    keys = list(params.keys())

    for key in keys:
        name = name+key+'_'+str(params[key])[0:5]+'_'

    name = name[0:-1]  # removes trailing _
    return name.replace('.', '')


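# Ray Tune trainable: train PPO with the sampled ent_coef, checkpoint along the way,
# then report the mean evaluation reward of the final checkpoint back to Tune.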
def train(parameterization):
    name = gen_filename(parameterization)
    folder = str(Path.home())+'/policy_logs/'+name+'/'
    # save_freq is counted in vec-env steps, so checkpoints land every save_freq * n_envs environment steps
    checkpoint_callback = CheckpointCallback(save_freq=400, save_path=folder)

    env = make_env(8)
    # try:
    model = PPO("MlpPolicy", env,
                gamma=.99,
                n_steps=1024,
                ent_coef=parameterization['ent_coef'],
                batch_size=128,
                tensorboard_log=(str(Path.home())+'/tensorboard_logs/'+name+'/'),
                policy_kwargs={"net_arch": [256, 256]})
    model.learn(total_timesteps=3000000, callback=checkpoint_callback)  # timesteps of each agent; was 4 million

    mean_reward = evaluate_all_policies(name)
    # except:
    #     mean_reward = -250
    tune.report(mean_reward=mean_reward)


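# Connect to an existing Ray cluster and run 100 Optuna-guided trials, at most 10 concurrently.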
ray.init(address='auto')

analysis = tune.run(
    train,
    num_samples=100,
    search_alg=ConcurrencyLimiter(optuna_search, max_concurrent=10),
    verbose=2,
    resources_per_trial={"gpu": 1, "cpu": 5},
)
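
# Possible follow-up (not in the original paste): tune.run returns an ExperimentAnalysis,
# so the best sampled ent_coef could be read back with something like:
# print(analysis.get_best_config(metric="mean_reward", mode="max"))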