Guest User

Untitled

a guest
Sep 8th, 2024
39
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.00 KB | None | 0 0
  1. def _loss_fn(params):
  2.                     _, q_vals = jax.vmap(network.apply, in_axes=(None, 0, 0, 0))(
  3.                         params["agent"],
  4.                         init_hs,
  5.                         _obs,
  6.                         _dones,
  7.                     )  # (num_agents, timesteps, batch_size, num_actions)
  8.  
  9.                     # get logits of the chosen actions
  10.                     chosen_action_q_vals = jnp.take_along_axis(
  11.                         q_vals,
  12.                         _actions[..., np.newaxis],
  13.                         axis=-1,
  14.                     ).squeeze(-1)  # (num_agents, timesteps, batch_size,)
  15.  
  16.                     unavailable_actions = 1 - _avail_actions
  17.                     valid_q_vals = q_vals - (unavailable_actions * 1e10)
  18.  
  19.                     # get the q values of the next state
  20.                     q_next = jnp.take_along_axis(
  21.                         q_next_target,
  22.                         jnp.argmax(valid_q_vals, axis=-1)[..., np.newaxis],
  23.                         axis=-1,
  24.                     ).squeeze(-1)  # (num_agents, timesteps, batch_size,)
  25.  
  26.                     qmix_next = mixer.apply(
  27.                         train_state.target_network_params["mixer"],
  28.                         q_next,
  29.                         minibatch.obs["__all__"],
  30.                     )
  31.                     qmix_target = (
  32.                         minibatch.rewards["__all__"][:-1]
  33.                         + (
  34.                             1 - minibatch.dones["__all__"][:-1]
  35.                         )  # use next done because last done was saved for rnn re-init
  36.                         * config["GAMMA"]
  37.                         * qmix_next[1:]  # sum over agents
  38.                     )
  39.  
  40.                     qmix = mixer.apply(
  41.                         params["mixer"], chosen_action_q_vals, minibatch.obs["__all__"]
  42.                     )[:-1]
  43.                     loss = jnp.mean((qmix - jax.lax.stop_gradient(qmix_target)) ** 2)
  44.  
  45.                     return loss, chosen_action_q_vals.mean()
Advertisement
Add Comment
Please, Sign In to add comment