Creating the model
fierydragon789 · Mar 17th, 2024 (edited)

# Assumes `env` is a grouped multi-agent TorchRL environment with a "team0"
# group exposing ("team0", "observation") and ("team0", "action") entries
# ("action" as the leaf key name is an assumption).
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal

device = "cuda:0" if torch.cuda.is_available() else "cpu"

share_parameters_policy = True
policy_net = torch.nn.Sequential(
    MultiAgentMLP(
        n_agent_inputs=env.observation_spec["team0", "observation"].shape[
            -1
        ],  # n_obs_per_agent
        n_agent_outputs=2
        * env.full_action_spec["team0", "action"].shape[
            -1
        ],  # 2 * n_actions_per_agent: a loc and a scale for each action dim
        n_agents=3,
        centralised=False,  # decentralised policies: each agent acts from its own observation
        share_params=share_parameters_policy,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=torch.nn.Tanh,
    ),
    NormalParamExtractor(),  # separates the last dimension into two outputs: a loc and a non-negative scale
)

policy_module = TensorDictModule(
    policy_net,
    in_keys=[("team0", "observation")],
    out_keys=[("team0", "loc"), ("team0", "scale")],
)
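Before wrapping the module in an actor, a quick shape check helps catch spec mistakes early. A minimal sketch, assuming env.reset() returns a tensordict carrying ("team0", "observation"):

td = policy_module(env.reset())
# loc and scale should each have shape (*batch, n_agents, n_actions)
print(td["team0", "loc"].shape, td["team0", "scale"].shape)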
policy = ProbabilisticActor(
    module=policy_module,
    # spec=env.unbatched_action_spec,  # alternative on batched envs that expose it (e.g. VMAS)
    spec=env.action_spec,
    in_keys=[("team0", "loc"), ("team0", "scale")],
    out_keys=[("team0", "action")],  # env.action_keys is already a list, so don't wrap it in another list
    distribution_class=TanhNormal,
    distribution_kwargs={
        # index the composite spec with the full nested key, not a list of keys
        "min": env.full_action_spec["team0", "action"].space.low,
        "max": env.full_action_spec["team0", "action"].space.high,
    },
    return_log_prob=True,  # writes the log-probability of the sampled action
    log_prob_key=("team0", "sample_log_prob"),
)
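To confirm the full pipeline writes an action and its log-probability, a minimal smoke test under the same assumptions (grouped "team0" keys, a constructed env):

td = policy(env.reset())
print(td["team0", "action"].shape)           # (*batch, n_agents, n_actions), kept inside the bounds by TanhNormal
print(td["team0", "sample_log_prob"].shape)  # one log-prob per agent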