Sitting-Down

Experiment_DQN

Jul 21st, 2022 (edited)
import os
import itertools
import sys
import ray
import sumo_rl
import supersuit

if 'SUMO_HOME' in os.environ:
    tools = os.path.join(os.environ['SUMO_HOME'], 'tools')
    sys.path.append(tools)
else:
    sys.exit("Please declare the environment variable 'SUMO_HOME'")

from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.env import ParallelPettingZooEnv
from ray.tune.registry import register_env

SAVE_PATH = 'D:/0-School/Experiment/checkpoints/c'
LOAD_PATH = 'D:/0-School/Experiment/checkpoints/c/checkpoint_000026/checkpoint-26'
def create_env(env_config):
    # Multi-agent SUMO traffic-signal environment (parallel PettingZoo API).
    env = sumo_rl.parallel_env(
        route_file='D:/0-School/Experiment/Simulation/routes/Day1AM.rou.xml',
        net_file='D:/0-School/Experiment/Simulation/nets/main_02.net.xml',
        out_csv_name='D:/0-School/Experiment/outputs/Day1/AM/QL-Learned/dqn',
        single_agent=False,
        use_gui=False,
        fixed_ts=False,
        num_seconds=14400,   # 4-hour simulation
        min_green=5,
        max_green=60,
        yellow_time=4,
        delta_time=5,        # agents act every 5 simulated seconds
        sumo_warnings=False,
        reward_fn='diff-waiting-time'
    )

    # Pad observations and action spaces so every intersection exposes
    # identical spaces and can be controlled by one shared policy.
    env = supersuit.pad_observations_v0(env)
    env = supersuit.pad_action_space_v0(env)

    # Wrap for RLlib's multi-agent interface.
    env = ParallelPettingZooEnv(env)

    return env
if __name__ == '__main__':
    ray.init()
    env = create_env('')

    # After padding, all agents share the same spaces.
    padded_action_space = env.action_space
    padded_observation_space = env.observation_space

    register_env("dqn_3_intersections", lambda config: create_env(config))

    trainer = DQNTrainer(env="dqn_3_intersections", config={
        "multiagent": {
            "policies": {
                # One shared DQN policy ('0') for all traffic lights.
                '0': (DQNTFPolicy, padded_observation_space, padded_action_space, {}),
            },
            # Map every agent to the shared policy.
            "policy_mapping_fn": (lambda _: '0')
        },
        "lr": 0.001,
        'log_level': 'INFO',
        'num_workers': 2,
        'num_gpus': 1,
        "no_done_at_end": True
    })
    # Resume training from the previous checkpoint.
    trainer.restore(LOAD_PATH)
    print('loaded')
    for step in itertools.count():
        trainer.train()  # distributed training step
        if (step != 0) and (step % 10 == 0):  # checkpoint every 10 iterations
            print('saving checkpoint at step', step)
            trainer.save(SAVE_PATH)
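Once the checkpoint is restored, the shared policy can also be rolled out for a quick sanity check instead of training. Below is a minimal evaluation sketch, not part of the original paste: it assumes the `trainer` and `create_env` defined above, and uses RLlib's `Trainer.compute_action` (Ray 1.x API) to act for every intersection with the shared policy '0'.

# --- Evaluation sketch (illustrative; assumes trainer/create_env above) ---
eval_env = create_env('')
obs = eval_env.reset()
done = {'__all__': False}
total_reward = 0.0

while not done['__all__']:
    # Query the shared policy '0' for each live agent's action.
    actions = {
        agent_id: trainer.compute_action(agent_obs, policy_id='0')
        for agent_id, agent_obs in obs.items()
    }
    obs, rewards, done, infos = eval_env.step(actions)
    total_reward += sum(rewards.values())

print('episode reward:', total_reward)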