Advertisement
Guest User

Untitled

a guest
Sep 21st, 2017
65
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.58 KB | None | 0 0
  1. def accuracy_2d(output, target, topk=(1,)):
  2. """
  3. Computes the precision@k for the specified values of k
  4.  
  5. Considers output is : NxCxHxW and target is : NxHxW
  6. """
  7. maxk = max(topk)
  8. batch_size = target.size(0)
  9. total_nelem = batch_size*target.size(-1)*target.size(-2)
  10.  
  11. _, pred = output.topk(maxk, 1, True, True)
  12. correct = target.unsqueeze(1).expand(pred.size())
  13. correct = pred.eq(correct)
  14.  
  15. res = []
  16. for k in topk:
  17. correct_k = correct[:, :k].contiguous().view(-1).float().sum(0)
  18. res.append(correct_k.mul_(100.0 / total_nelem))
  19. return res
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement