Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# --- Imports and environment setup ---------------------------------------
import itertools
import os
import sys

import ray
import sumo_rl
import supersuit

# SUMO ships its Python tools (TraCI, sumolib) under $SUMO_HOME/tools;
# they must be on sys.path before any SUMO environment is instantiated.
if 'SUMO_HOME' in os.environ:
    tools = os.path.join(os.environ['SUMO_HOME'], 'tools')
    sys.path.append(tools)
else:
    sys.exit("Please declare the environment variable 'SUMO_HOME'")

from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.env import ParallelPettingZooEnv
from ray.tune.registry import register_env

# Directory new checkpoints are written to, and the checkpoint to resume from.
SAVE_PATH = 'D:/0-School/Experiment/checkpoints/c'
LOAD_PATH = 'D:/0-School/Experiment/checkpoints/c/checkpoint_000026/checkpoint-26'
def create_env(env_config):
    """Build the padded multi-agent SUMO traffic-signal environment.

    ``env_config`` is accepted to satisfy RLlib's env-creator signature but
    is not used; every setting below is hard-coded.
    """
    raw_env = sumo_rl.parallel_env(
        route_file='D:/0-School/Experiment/Simulation/routes/Day1AM.rou.xml',
        net_file='D:/0-School/Experiment/Simulation/nets/main_02.net.xml',
        out_csv_name='D:/0-School/Experiment/outputs/Day1/AM/QL-Learned/dqn',
        single_agent=False,
        use_gui=False,
        fixed_ts=False,
        num_seconds=14400,
        min_green=5,
        max_green=60,
        yellow_time=4,
        delta_time=5,
        sumo_warnings=False,
        reward_fn='diff-waiting-time'
    )
    # Pad observations first, then action spaces (same order as before), so
    # every intersection exposes one common shape and a single policy can
    # serve all agents.
    padded = supersuit.pad_observations_v0(raw_env)
    padded = supersuit.pad_action_space_v0(padded)
    return ParallelPettingZooEnv(padded)
if __name__ == '__main__':
    ray.init()

    # Build one throwaway env just to read the padded (shared) observation
    # and action spaces that the single shared policy must use.
    env = create_env('')
    padded_action_space = env.action_space
    padded_observation_space = env.observation_space

    register_env("dqn_3_intersections", lambda config: create_env(config))

    trainer = DQNTrainer(env="dqn_3_intersections", config={
        "multiagent": {
            "policies": {
                # One shared DQN policy ('0') used by every intersection.
                '0': (DQNTFPolicy, padded_observation_space, padded_action_space, {}),
            },
            "policy_mapping_fn": (lambda _: '0')
        },
        "lr": 0.001,
        'log_level': 'INFO',
        'num_workers': 2,
        'num_gpus': 1,
        "no_done_at_end": True
    })

    # Resume training from the previously saved checkpoint.
    trainer.restore(LOAD_PATH)
    print('loaded')

    for step in itertools.count():
        trainer.train()  # distributed training step
        # BUG FIX: the original condition `(step % 10)` was truthy on every
        # step NOT divisible by 10, so it checkpointed 9 out of every 10
        # iterations and skipped the intended multiples of 10. Save every
        # 10th step instead.
        if (step != 0) and (step % 10 == 0):
            print('save ne')
            print(step)
            trainer.save(SAVE_PATH)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement