import jax
import jax.numpy as jnp
import numpy as np

# network, mixer, init_hs, _obs, _dones, _actions, _avail_actions,
# q_next_target, minibatch, train_state and config come from the
# enclosing training-step scope (not included in this paste).
def _loss_fn(params):
    # per-agent Q-values from the recurrent agent network
    _, q_vals = jax.vmap(network.apply, in_axes=(None, 0, 0, 0))(
        params["agent"],
        init_hs,
        _obs,
        _dones,
    )  # (num_agents, timesteps, batch_size, num_actions)

    # Q-values of the chosen actions
    chosen_action_q_vals = jnp.take_along_axis(
        q_vals,
        _actions[..., np.newaxis],
        axis=-1,
    ).squeeze(-1)  # (num_agents, timesteps, batch_size)

    # mask out unavailable actions before taking the argmax
    unavailable_actions = 1 - _avail_actions
    valid_q_vals = q_vals - (unavailable_actions * 1e10)

    # target-network Q-values of the actions selected greedily by the online network
    q_next = jnp.take_along_axis(
        q_next_target,
        jnp.argmax(valid_q_vals, axis=-1)[..., np.newaxis],
        axis=-1,
    ).squeeze(-1)  # (num_agents, timesteps, batch_size)

    # mix per-agent target Q-values into a joint value with the target mixer
    qmix_next = mixer.apply(
        train_state.target_network_params["mixer"],
        q_next,
        minibatch.obs["__all__"],
    )

    # one-step TD target; use the next done flag because the last done
    # was saved for RNN re-initialisation
    qmix_target = (
        minibatch.rewards["__all__"][:-1]
        + (1 - minibatch.dones["__all__"][:-1])
        * config["GAMMA"]
        * qmix_next[1:]  # mixed Q-values shifted to the next timestep
    )

    # mix the online chosen-action Q-values and regress onto the stopped-gradient target
    qmix = mixer.apply(
        params["mixer"], chosen_action_q_vals, minibatch.obs["__all__"]
    )[:-1]
    loss = jnp.mean((qmix - jax.lax.stop_gradient(qmix_target)) ** 2)

    return loss, chosen_action_q_vals.mean()
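
A minimal sketch of how a loss function with this signature is typically consumed in a JAX update step. The optimizer choice, learning rate, and the separate opt_state argument below are assumptions for illustration; the actual training loop around _loss_fn is not part of this paste.

import jax
import optax

optimizer = optax.adam(learning_rate=5e-4)  # hypothetical optimizer and learning rate

def _sketch_update(params, opt_state):
    # has_aux=True because _loss_fn returns (loss, mean chosen-action Q-value)
    (loss, mean_q), grads = jax.value_and_grad(_loss_fn, has_aux=True)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, mean_q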