Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class AlphaLoss(torch.nn.Module):
    """Combined value + policy loss in the AlphaZero style.

    For each sample the loss is the squared error between predicted and target
    value plus the cross-entropy ``-sum(target_policy * log(pred_policy))`` over
    the action dimension; the per-sample losses are averaged over the batch.
    """

    def __init__(self, eps: float = 1e-8):
        """Create the loss module.

        Args:
            eps: Small constant added to the predicted policy before taking the
                log so that zero probabilities do not yield ``log(0)``. Defaults
                to ``1e-8``, the value the original implementation hard-coded.
        """
        super().__init__()
        self.eps = eps

    def forward(self, y_value, value, y_policy, policy):
        """Return the batch-mean combined loss as a 0-dim float tensor.

        Args:
            y_value: Predicted value(s), one element per sample (e.g. ``(B,)``
                or ``(B, 1)``).
            value: Target value(s), one element per sample.
            y_policy: Predicted move probabilities, shape ``(B, A)``. Assumed to
                already be probabilities (e.g. softmax output), not logits —
                TODO confirm against caller.
            policy: Target move distribution, shape ``(B, A)``.

        Returns:
            Scalar tensor: mean over the batch of value error + policy error.
        """
        # Flatten both value tensors *before* subtracting: a (B, 1) prediction
        # minus a (B,) target would otherwise broadcast to (B, B), which either
        # crashes when added to the (B,) policy term or silently gives a wrong loss.
        value_error = (value.reshape(-1) - y_value.reshape(-1)) ** 2
        # Cross-entropy of the target policy against the predicted policy,
        # summed over the action dimension (dim 1) -> shape (B,).
        policy_error = torch.sum(-policy * (self.eps + y_policy.float()).log(), dim=1)
        return (value_error.float() + policy_error).mean()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement