Untitled

a guest
Apr 25th, 2019
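
import torch


class AlphaLoss(torch.nn.Module):
    """AlphaZero-style loss: squared value error plus policy cross-entropy."""

    def __init__(self):
        super(AlphaLoss, self).__init__()

    def forward(self, y_value, value, y_policy, policy):
        # y_value / y_policy appear to be network outputs; value / policy the targets.
        value_error = (value - y_value) ** 2
        # Cross-entropy between the target policy and the predicted move
        # probabilities, summed over moves (dim 1); 1e-8 guards against log(0).
        policy_error = torch.sum(-policy * (1e-8 + y_policy.float()).log(), 1)
        # Flatten value_error to (batch,) so it adds elementwise to policy_error.
        total_error = (value_error.view(-1).float() + policy_error).mean()
        return total_error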
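To make the arithmetic concrete, here is a dependency-free sketch that mirrors the same per-batch computation with plain Python lists instead of tensors. It averages, over the batch, the squared value error minus the sum of target-weighted log predicted probabilities. The function name `alpha_loss_reference` and the sample numbers are hypothetical, chosen only for illustration; it is not part of the original paste.

```python
import math
from typing import Sequence


def alpha_loss_reference(
    pred_values: Sequence[float],
    target_values: Sequence[float],
    pred_policies: Sequence[Sequence[float]],
    target_policies: Sequence[Sequence[float]],
    eps: float = 1e-8,
) -> float:
    """Hypothetical pure-Python mirror of AlphaLoss.forward for one batch."""
    total = 0.0
    for v_hat, v, p_hat, p in zip(pred_values, target_values,
                                  pred_policies, target_policies):
        # Squared error between target value and predicted value.
        value_error = (v - v_hat) ** 2
        # Cross-entropy: -sum_a target(a) * log(predicted(a) + eps).
        policy_error = -sum(t * math.log(eps + q) for t, q in zip(p, p_hat))
        total += value_error + policy_error
    # Mean over the batch, like .mean() in the torch version.
    return total / len(pred_values)


loss = alpha_loss_reference([0.5], [1.0], [[0.5, 0.5]], [[1.0, 0.0]])
print(loss)
```

With one sample, a predicted value of 0.5 against a target of 1.0 contributes 0.25, and a uniform two-move prediction against a one-hot target contributes about ln 2, so the loss comes out near 0.943.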