Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def accuracy_2d(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, of dense per-element predictions.

    The intended use is 2D dense prediction, where ``output`` is ``N x C x H x W``
    (class scores on dimension 1) and ``target`` is ``N x H x W`` (class indices).
    Any ``target`` of shape ``N x ...`` paired with an ``output`` of shape
    ``N x C x ...`` is handled the same way, since the element count is taken
    from ``target.numel()``.

    Args:
        output: Tensor of class scores with the class dimension at index 1.
        target: Tensor of ground-truth class indices; same shape as ``output``
            with the class dimension removed.
        topk: Iterable of ``k`` values for which the accuracy is reported.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk``, each holding
        the percentage of elements whose label is among the ``k`` best-scoring
        classes. All entries are ``0`` when ``target`` has no elements.
    """
    maxk = max(topk)
    total_nelem = target.numel()
    if total_nelem == 0:
        # Nothing to score: report 0% rather than dividing by zero below.
        return [output.new_zeros(()).float() for _ in topk]

    # Indices of the maxk best-scoring classes along the class dim, best first.
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    # Repeat the labels along the k axis so every candidate class is compared
    # against the label of its own element.
    correct = pred.eq(target.unsqueeze(1).expand_as(pred))

    scale = 100.0 / total_nelem
    # Candidates are sorted best-first, so the first k slices are the top-k hits.
    return [correct[:, :k].float().sum() * scale for k in topk]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement