Guest User

Untitled

a guest
Apr 6th, 2021
279
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from stable_baselines3 import PPO
  2. from stable_baselines3.common.callbacks import CheckpointCallback
  3. #from pettingzoo.butterfly import pistonball_v4
  4. import supersuit as ss
  5. from ray import tune
  6. from ray.tune.suggest.optuna import OptunaSearch
  7. import optuna
  8. import os
  9. import ray
  10. from pathlib import Path
  11. import gym
  12. from ray.tune.suggest import ConcurrencyLimiter
  13.  
  14. space = {
  15.     "n_epochs": optuna.distributions.IntUniformDistribution(3, 50),
  16.     "gamma": optuna.distributions.LogUniformDistribution(.9, .999),
  17.     "ent_coef": optuna.distributions.LogUniformDistribution(.001, .1),
  18.     "learning_rate": optuna.distributions.LogUniformDistribution(5e-6, 5e-4),
  19.     "vf_coef": optuna.distributions.UniformDistribution(.1, 1),
  20.     "gae_lambda": optuna.distributions.UniformDistribution(.8, 1),
  21.     "max_grad_norm": optuna.distributions.LogUniformDistribution(.01, 10),
  22.     "n_steps": optuna.distributions.CategoricalDistribution([128, 256, 512, 1024, 2048, 4096]),
  23.     "batch_size": optuna.distributions.CategoricalDistribution([32, 64, 128, 256]),  # , 512, 1024, 2048, 4096
  24.     "n_envs": optuna.distributions.CategoricalDistribution([2, 4, 8]),
  25.     "clip_range": optuna.distributions.UniformDistribution(.1, 5),
  26. }
  27.  
  28.  
  29. optuna_search = OptunaSearch(
  30.     space,
  31.     metric="mean_reward",
  32.     mode="max")
  33.  
  34.  
  35. def make_env(n_envs):
  36.     if n_envs is None:
  37.         env = gym.make('LunarLanderContinuous-v2')
  38.     else:
  39.         env = gym.make('LunarLanderContinuous-v2')
  40.         env = ss.stable_baselines3_vec_env_v0(env, n_envs, multiprocessing=False)
  41.     return env
  42.  
  43.  
  44. def evaluate_all_policies(name):
  45.  
  46.     def evaluate_policy(env, model):
  47.         total_reward = 0
  48.         NUM_RESETS = 100
  49.  
  50.         for i in range(NUM_RESETS):
  51.             done = False
  52.             obs = env.reset()
  53.             while not done:
  54.                 act = model.predict(obs, deterministic=True)[0] if not done else None
  55.                 observation, reward, done, info = env.step(act)
  56.                 total_reward += reward
  57.  
  58.         return total_reward/NUM_RESETS
  59.  
  60.     env = make_env(None)
  61.     policy_folder = str(Path.home())+'/policy_logs/'+name+'/'
  62.     policy_files = os.listdir(policy_folder)
  63.     policy_file = sorted(policy_files, key=lambda x: int(x[9:-10]))[-1]
  64.     model = PPO.load(policy_folder+policy_file)
  65.  
  66.     return evaluate_policy(env, model)
  67.  
  68.  
  69. def gen_filename(params):
  70.     name = ''
  71.     keys = list(params.keys())
  72.  
  73.     for key in keys:
  74.         name = name+key+'_'+str(params[key])[0:5]+'_'
  75.  
  76.     name = name[0:-1]  # removes trailing _
  77.     return name.replace('.', '')
  78.  
  79.  
  80. def train(parameterization):
  81.     name = gen_filename(parameterization)
  82.     folder = str(Path.home())+'/policy_logs/'+name+'/'
  83.     checkpoint_callback = CheckpointCallback(save_freq=400, save_path=folder)
  84.  
  85.     env = make_env(parameterization['n_envs'])
  86.     model = PPO("MlpPolicy", env, gamma=parameterization['gamma'], n_steps=parameterization['n_steps'], ent_coef=parameterization['ent_coef'], learning_rate=parameterization['learning_rate'], vf_coef=parameterization['vf_coef'], max_grad_norm=parameterization['max_grad_norm'], gae_lambda=parameterization['gae_lambda'], batch_size=parameterization['batch_size'], clip_range=parameterization['clip_range'], n_epochs=parameterization['n_epochs'], tensorboard_log=(str(Path.home())+'/tensorboard_logs/'+name+'/'), policy_kwargs={"net_arch": [256, 256]})
  87.     model.learn(total_timesteps=2000000, callback=checkpoint_callback)  # time steps steps of each agent; was 4 million
  88.  
  89.     mean_reward = evaluate_all_policies(name)
  90.     tune.report(mean_reward=mean_reward)
  91.  
  92.  
  93. ray.init(address='auto')
  94.  
  95. analysis = tune.run(
  96.     train,
  97.     num_samples=100,
  98.     search_alg=ConcurrencyLimiter(optuna_search, max_concurrent=10),
  99.     verbose=2,
  100.     resources_per_trial={"gpu": 1, "cpu": 5},
  101. )
RAW Paste Data