PettingZoo + RLlib requires the same observation spaces for everyone!
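Minimal reproduction script: a PettingZoo ParallelEnv whose two agents deliberately use different observation and action spaces, wrapped with RLlib's ParallelPettingZooEnv and trained with DQN using one PolicySpec per agent.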

from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete
from pettingzoo.utils.env import ParallelEnv
from pettingzoo.test import parallel_api_test
import ray
from ray.tune.registry import register_env
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.tune.logger import pretty_print


class DifferentSpacesEnv(ParallelEnv):
    """
    PettingZoo environment whose agents have different
    observation and action spaces.
    """

    # PettingZoo environments are expected to define metadata;
    # render() returns a string, i.e. an "ansi"-style render.
    metadata = {"render_modes": ["ansi"], "name": "different_spaces_v0"}

    def __init__(self, **config):
        """
        Initialises the spaces. The logic is not important,
        just that the spaces differ between the two agents.
        """
        self.agents = ["agent_1", "agent_2"]
        self.possible_agents = self.agents.copy()

        self.observation_spaces = {
            "agent_1": Dict({
                "a": Discrete(3),
                "b": Dict({"e": Box(1, 3)}),
            }),
            "agent_2": Dict({
                "dis": Discrete(30),
                "dico": MultiDiscrete([11, 11]),
            }),
        }
        self.action_spaces = {
            "agent_1": Discrete(4),
            "agent_2": Dict({"e": Box(1, 3)}),
        }
        self.max_time = 1000
        self.curr_time = 0

    def step(self, actions: dict):
        """
        Takes a step and returns the observation, rewards,
        terminated, truncated and info dictionaries, each
        keyed by agent.
        """
        self.curr_time += 1
        rewards = {agent: 1 for agent in self.agents}
        obs = {
            agent: self.observation_spaces[agent].sample()
            for agent in self.agents
        }
        terminated = dict.fromkeys(self.agents, False)
        # Truncate every agent once the time limit is reached.
        truncated = dict.fromkeys(self.agents, self.curr_time >= self.max_time)
        # The PettingZoo API expects infos to map each agent to a dict.
        infos = {agent: {} for agent in self.agents}
        if self.curr_time >= self.max_time:
            # All agents are done once the episode is truncated.
            self.agents = []

        return obs, rewards, terminated, truncated, infos

    def reset(self, seed=None, options=None):
        """
        Restores the agent list, reseeds the spaces and
        returns the initial observations and infos.
        """
        self.agents = self.possible_agents[:]
        self.curr_time = 0
        for agent in self.agents:
            self.observation_spaces[agent].seed(seed)
            self.action_spaces[agent].seed(seed)

        obs = {
            agent: self.observation_spaces[agent].sample()
            for agent in self.agents
        }
        infos = {agent: {} for agent in self.agents}
        return obs, infos

    def render(self):
        # Render the game as a simple string.
        return "|".join(self.agents)

    def state(self):
        """
        This is a stateless game, therefore we return None.
        """
        return None

    def observation_space(self, agent: str):
        return self.observation_spaces[agent]

    def action_space(self, agent: str):
        return self.action_spaces[agent]


if __name__ == "__main__":
    # Assert that this is a valid ParallelEnv.
    # This will throw if it is not.
    parallel_api_test(DifferentSpacesEnv())

    # Convert this into a MultiAgentEnv.
    ray.init()
    game = DifferentSpacesEnv()
    register_env(
        "DifferentSpacesEnv",
        lambda config: ParallelPettingZooEnv(DifferentSpacesEnv(**config)),
    )

    # Run some sample algorithm on it.
    config = (
        DQNConfig()
        .environment("DifferentSpacesEnv")
        .training(train_batch_size=1024)
        .framework("torch")
        .multi_agent(
            policies={
                "agent_1": PolicySpec(
                    policy_class=None,  # infer automatically
                    observation_space=game.observation_space("agent_1"),
                    action_space=game.action_space("agent_1"),
                    config={},
                ),
                "agent_2": PolicySpec(
                    policy_class=RandomPolicy,
                    observation_space=game.observation_space("agent_2"),
                    action_space=game.action_space("agent_2"),
                    config={},
                ),
            },  # end-policies
            # Map each agent id to the policy of the same name.
            policy_mapping_fn=lambda ag_id, *args, **kwargs: ag_id,
            # Only agent_1 learns; RandomPolicy has nothing to train.
            policies_to_train=["agent_1"],
        )  # end-multi-agent
    )  # end-config

    algo = config.build()
    results = []
    for i in range(10):
        results.append(algo.train())

    for i in range(10):
        print(f"Result at iteration {i}:")
        print(pretty_print(results[i]))

    algo.stop()
    ray.shutdown()
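
For a quick sanity check of the environment outside RLlib, a minimal manual rollout over the parallel API could look like the sketch below (it assumes only the DifferentSpacesEnv class defined above and the same gymnasium/pettingzoo install):

# Manual rollout sketch: verify that each agent can be stepped with an
# action sampled from its own, different action space.
env = DifferentSpacesEnv()
obs, infos = env.reset(seed=42)
for _ in range(5):
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    obs, rewards, terminated, truncated, infos = env.step(actions)
    print(env.render(), rewards)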