Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
def calc_sil_policy_val_loss(self, batch):
    '''
    Calculate the SIL (self-imitation learning) policy and value losses.

    sil_policy_loss = -mean(log_prob * max(R - v_pred, 0))
    sil_val_loss    = mean(max(R - v_pred, 0) ** 2) / 2

    This is called on a randomly-sampled batch from experience replay.

    :param batch: replay batch; must contain 'states', plus whatever
        math_util.calc_returns and self.calc_log_probs read from it
    :returns: tuple (sil_policy_loss, sil_val_loss) of scalar tensors
    '''
    returns = math_util.calc_returns(batch, self.gamma)
    v_preds = self.calc_v(batch['states'])
    # NOTE(review): assumes returns and v_preds have the same shape;
    # e.g. (B,) vs (B, 1) would silently broadcast to (B, B) — confirm.
    # Keep only the positive part: SIL imitates transitions whose return
    # exceeded the current value estimate.
    clipped_advs = torch.clamp(returns - v_preds, min=0.0)
    log_probs = self.calc_log_probs(batch)
    # The clipped advantage is a weight on the policy gradient; detach it
    # so the policy loss does not push gradients into the value estimate.
    sil_policy_loss = -torch.mean(log_probs * clipped_advs.detach())
    # Half squared error on the positive advantage, averaged over the batch.
    sil_val_loss = torch.mean(clipped_advs.pow(2)) / 2
    if torch.cuda.is_available() and self.net.gpu:
        sil_policy_loss = sil_policy_loss.cuda()
        sil_val_loss = sil_val_loss.cuda()
    return sil_policy_loss, sil_val_loss
Add Comment
Please sign in to add a comment.