Advertisement
Guest User

Untitled

a guest
Apr 20th, 2019
124
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 0.71 KB | None | 0 0
  1. def get_accuracy(p, threshold):
  2. net = Net()
  3. net.load_state_dict(torch.load('./models/p=' + str(p) + '_t=' + str(threshold)))
  4. net.eval()
  5. total = 0
  6. second_total = 0
  7. class_correct = list(0. for i in range(10))
  8. class_total = list(0. for i in range(10))
  9. with torch.no_grad():
  10. for data in testloader:
  11. images, labels = data
  12. outputs = net(images, is_trained = True)
  13. _, predicted = torch.max(outputs, 1)
  14. c = (predicted == labels).squeeze()
  15. label = labels[0]
  16. class_correct[label] += c.item()
  17. class_total[label] += 1
  18. accuracy = sum(class_correct)/sum(class_total) * 100
  19. return accuracy
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement