from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete
from pettingzoo.utils.env import ParallelEnv
from pettingzoo.test import parallel_api_test

import ray
from ray.tune.registry import register_env
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.tune.logger import pretty_print


class DifferentSpacesEnv(ParallelEnv):
    """
    PettingZoo parallel environment whose agents use different
    observation and action spaces.
    """

    def __init__(self, **config):
        """
        Initialises the environment. The logic is not important;
        what matters is that the two agents use different spaces.
        """
        self.agents = ["agent_1", "agent_2"]
        self.possible_agents = self.agents.copy()
        self.observation_spaces = {
            "agent_1": Dict({
                "a": Discrete(3),
                "b": Dict({
                    "e": Box(1, 3),
                }),
            }),
            "agent_2": Dict({
                "dis": Discrete(30),
                "dico": MultiDiscrete([11, 11]),
            }),
        }
        self.action_spaces = {
            "agent_1": Discrete(4),
            "agent_2": Dict({"e": Box(1, 3)}),
        }
        self.max_time = 1000
        self.curr_time = 0
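        # Illustrative samples (values are made up, just to show the nested
        # structure): an agent_1 observation looks like
        #   {"a": 2, "b": {"e": array([1.7], dtype=float32)}}
        # while an agent_2 observation looks like
        #   {"dis": 17, "dico": array([3, 9])}.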

    def step(self, actions: dict):
        """
        Takes a step and returns the observation,
        reward, terminated, truncated and info dictionaries,
        each keyed by agent.
        """
        self.curr_time += 1
        rewards = {agent: 1 for agent in self.agents}
        obs = {
            agent: self.observation_spaces[agent].sample()
            for agent in self.agents
        }
        terminated = dict.fromkeys(self.agents, False)
        # Truncate the episode once the time limit is reached.
        truncated = dict.fromkeys(self.agents, self.curr_time >= self.max_time)
        # Per-agent infos must be dictionaries, not booleans.
        info = {agent: {} for agent in self.agents}
        if all(truncated.values()):
            self.agents = []
        return obs, rewards, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """
        Resets the episode and returns the initial observations
        and infos, each keyed by agent.
        """
        self.agents = self.possible_agents[:]
        self.curr_time = 0
        for agent in self.agents:
            self.observation_spaces[agent].seed(seed)
            self.action_spaces[agent].seed(seed)
        obs = {
            agent: self.observation_spaces[agent].sample()
            for agent in self.agents
        }
        infos = {agent: {} for agent in self.agents}
        return obs, infos

    def render(self):
        # Render the game as a simple string.
        return "|".join(self.agents)

    def state(self):
        """
        This is a stateless game, so there is no global state to return.
        """
        return None

    def observation_space(self, agent: str):
        return self.observation_spaces[agent]

    def action_space(self, agent: str):
        return self.action_spaces[agent]


if __name__ == "__main__":
    # Check that the environment follows the PettingZoo Parallel API.
    # parallel_api_test raises an exception if it does not.
    parallel_api_test(DifferentSpacesEnv())

    # Convert the environment into an RLlib MultiAgentEnv.
    ray.init()
    game = DifferentSpacesEnv()
    register_env(
        "DifferentSpacesEnv",
        lambda config: ParallelPettingZooEnv(DifferentSpacesEnv(**config)),
    )
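    # Note: RLlib calls the lambda above with the env_config dict set via
    # .environment(env_config=...); since none is set here, an empty dict is
    # forwarded to DifferentSpacesEnv through **config.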

    # Run a sample algorithm on it.
    config = (
        DQNConfig()
        .environment("DifferentSpacesEnv")
        .training(train_batch_size=1024)
        .framework("torch")
        .multi_agent(
            policies={
                "agent_1": PolicySpec(
                    policy_class=None,  # infer automatically
                    observation_space=game.observation_space("agent_1"),
                    action_space=game.action_space("agent_1"),
                    config={},
                ),
                "agent_2": PolicySpec(
                    policy_class=RandomPolicy,
                    observation_space=game.observation_space("agent_2"),
                    action_space=game.action_space("agent_2"),
                    config={},
                ),
            },  # end of policies
            # Map every agent id to the policy of the same name.
            policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,
            policies_to_train=[
                "agent_1",
                "agent_2",
            ],  # train all policies
        )  # end of multi_agent
    )  # end of config

    algo = config.build()
    results = []
    for i in range(10):
        results.append(algo.train())
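    # Optional sketch (not part of the original paste): query the trained
    # DQN policy for a single greedy action. This assumes that
    # Algorithm.compute_single_action accepts a raw observation together
    # with a policy_id, which holds for the RLlib versions this example targets.
    sample_obs = game.observation_space("agent_1").sample()
    sample_action = algo.compute_single_action(
        sample_obs, policy_id="agent_1", explore=False
    )
    print(f"Sample greedy action for agent_1: {sample_action}")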

    for i, result in enumerate(results):
        print(f"Result at iteration {i}:")
        print(pretty_print(result))

    # Clean up: stop the algorithm and shut down Ray.
    algo.stop()
    ray.shutdown()