Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# NOTE(review): this entire policy definition is immediately re-created below
# with spec-derived sizes; once the second assignment runs, this one is dead
# code. Kept for reference — consider deleting one of the two copies.

# Whether all agents share one set of policy weights (homogeneous agents).
share_parameters_policy = True

# Per-agent MLP: each agent maps its own observation to the parameters of a
# Gaussian action distribution (loc + scale), hence 2 * n_actions outputs.
policy_net = torch.nn.Sequential(
    MultiAgentMLP(
        n_agent_inputs=env.observation_spec["team0", "observation"].shape[
            -1
        ],  # n_obs_per_agent
        n_agent_outputs=2 * 4,  # 2 * n_actions_per_agent (4 hard-coded here;
        # the redefinition below derives it from the action spec instead)
        n_agents=3,
        centralised=False,  # decentralised: each agent acts from its own observation
        share_params=share_parameters_policy,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=torch.nn.Tanh,
    ),
    NormalParamExtractor(),  # splits the last dim into a loc and a non-negative scale
)

# Wrap the net so it reads/writes tensordict entries under the "team0" group.
policy_module = TensorDictModule(
    policy_net,
    in_keys=[("team0", "observation")],
    out_keys=[("team0", "loc"), ("team0", "scale")],
)
# Whether all agents share one set of policy weights (homogeneous agents).
share_parameters_policy = True

# Per-agent MLP producing Gaussian distribution parameters (loc + scale)
# for each agent from that agent's own observation.
policy_net = torch.nn.Sequential(
    MultiAgentMLP(
        # FIX: index the leaf ("team0", "observation") — indexing the
        # composite spec with "team0" alone returns the group spec, whose
        # .shape is the batch/agent shape, not the per-agent observation
        # width (the earlier definition in this file uses the leaf key).
        n_agent_inputs=env.observation_spec[
            "team0", "observation"
        ].shape[-1],  # n_obs_per_agent
        # 2 * n_actions: one loc and one scale per action dimension.
        # NOTE(review): assumes env.action_spec["team0"] exposes the
        # per-agent action width in its last dimension — confirm against
        # this env's spec tree.
        n_agent_outputs=2 * env.action_spec["team0"].shape[-1],
        n_agents=3,
        centralised=False,  # decentralised: each agent acts from its own observation
        share_params=share_parameters_policy,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=torch.nn.Tanh,
    ),
    NormalParamExtractor(),  # splits the last dim into a loc and a non-negative scale
)

# Wrap the net so it reads/writes tensordict entries under the "team0" group.
policy_module = TensorDictModule(
    policy_net,
    in_keys=[("team0", "observation")],
    out_keys=[("team0", "loc"), ("team0", "scale")],
)
# Probabilistic actor: samples actions from a TanhNormal built from the
# (loc, scale) entries the policy module writes under "team0".
policy = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=[("team0", "loc"), ("team0", "scale")],
    # FIX: env.action_keys is already a list of keys; wrapping it in another
    # list ([env.action_keys]) produced a nested list, which is not a valid
    # out_keys value.
    out_keys=env.action_keys,
    distribution_class=TanhNormal,
    distribution_kwargs={
        # FIX: a spec cannot be indexed with a *list* of keys
        # (the original did env.action_spec["team0"][env.action_keys]);
        # index with the single action key to reach the bounded leaf spec.
        # NOTE(review): assumes exactly one action key — confirm for this env.
        "min": env.action_spec[env.action_keys[0]].space.low,
        "max": env.action_spec[env.action_keys[0]].space.high,
    },
    return_log_prob=True,  # also write per-sample log-probabilities
    log_prob_key=("team0", "sample_log_prob"),
)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement