from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
# from pettingzoo.butterfly import pistonball_v4
import supersuit as ss
from ray import tune
from ray.tune.suggest.optuna import OptunaSearch
import optuna
import os
import ray
from pathlib import Path
import gym
from ray.tune.suggest import ConcurrencyLimiter

space = {
    "ent_coef": optuna.distributions.LogUniformDistribution(.001, .1),
    "learning_rate": optuna.distributions.LogUniformDistribution(5e-6, 5e-4),
    "gae_lambda": optuna.distributions.CategoricalDistribution([0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0]),
}
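
# OptunaSearch below samples each trial's hyperparameters from these Optuna
# distribution objects and maximizes the "mean_reward" value that train()
# reports via tune.report().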
optuna_search = OptunaSearch(
    space,
    metric="mean_reward",
    mode="max")


def make_env(n_envs):
    """Return a single env for evaluation (n_envs is None) or a vectorized env for training."""
    if n_envs is None:
        # env = pistonball_v4.env(time_penalty=-1)
        env = gym.make('LunarLanderContinuous-v2')
    else:
        # env = pistonball_v4.parallel_env(time_penalty=-1)
        env = gym.make('LunarLanderContinuous-v2')
        env = ss.stable_baselines3_vec_env_v0(env, n_envs, multiprocessing=False)

    # env = ss.color_reduction_v0(env, mode='B')
    # env = ss.resize_v0(env, x_size=84, y_size=84)
    # env = ss.frame_stack_v1(env, 3)
    # if n_envs is not None:
    #     env = ss.pettingzoo_env_to_vec_env_v0(env)
    #     env = ss.concat_vec_envs_v0(env, 2*n_envs, num_cpus=4, base_class='stable_baselines')

    return env

def evaluate_all_policies(name):

    def evaluate_policy(env, model):
        total_reward = 0
        NUM_RESETS = 100
        # PettingZoo (pistonball) evaluation loop, kept for reference:
        # for i in range(NUM_RESETS):
        #     env.reset()
        #     for agent in env.agent_iter():
        #         obs, reward, done, info = env.last()
        #         total_reward += reward
        #         act = model.predict(obs, deterministic=True)[0] if not done else None
        #         env.step(act)
        for i in range(NUM_RESETS):
            done = False
            obs = env.reset()
            while not done:
                act = model.predict(obs, deterministic=True)[0]
                obs, reward, done, info = env.step(act)
                total_reward += reward

        return total_reward/NUM_RESETS

    env = make_env(None)
    policy_folder = str(Path.home())+'/policy_logs/'+name+'/'
    policy_files = os.listdir(policy_folder)
    # Checkpoints are saved as rl_model_<timesteps>_steps.zip; load the latest one
    policy_file = sorted(policy_files, key=lambda x: int(x[9:-10]))[-1]
    model = PPO.load(policy_folder+policy_file)

    return evaluate_policy(env, model)


def gen_filename(params):
    name = ''
    keys = list(params.keys())

    for key in keys:
        name = name+key+'_'+str(params[key])[0:5]+'_'

    name = name[0:-1]  # removes the trailing '_'
    return name.replace('.', '')
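# Example: gen_filename({'ent_coef': 0.01, 'learning_rate': 5e-05, 'gae_lambda': 0.9})
# returns 'ent_coef_001_learning_rate_5e-05_gae_lambda_09'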


def train(parameterization):
    name = gen_filename(parameterization)
    folder = str(Path.home())+'/policy_logs/'+name+'/'
    # save_freq counts callback calls (one per step of the vectorized env), so with
    # 8 parallel envs this checkpoints every 8*400 = 3200 agent timesteps
    checkpoint_callback = CheckpointCallback(save_freq=400, save_path=folder)

    env = make_env(8)
    # try:
    model = PPO("MlpPolicy", env, gamma=.99, n_steps=1024,
                ent_coef=parameterization['ent_coef'],
                learning_rate=parameterization['learning_rate'],
                gae_lambda=parameterization['gae_lambda'],
                batch_size=128,
                tensorboard_log=(str(Path.home())+'/tensorboard_logs/'+name+'/'),
                policy_kwargs={"net_arch": [256, 256]})
    model.learn(total_timesteps=2000000, callback=checkpoint_callback)  # timesteps of each agent; was 4 million

    mean_reward = evaluate_all_policies(name)
    # except:
    #     mean_reward = -250
    tune.report(mean_reward=mean_reward)


ray.init(address='auto')

analysis = tune.run(
    train,
    num_samples=100,
    search_alg=ConcurrencyLimiter(optuna_search, max_concurrent=10),
    verbose=2,
    resources_per_trial={"gpu": 1, "cpu": 5},
)
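
# Minimal follow-up sketch (assumes Ray Tune's standard ExperimentAnalysis API):
# read the best hyperparameters found by the search back from the returned object.
best_config = analysis.get_best_config(metric="mean_reward", mode="max")
print("Best trial config:", best_config)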