Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def get_accuracy(p, threshold) -> float:
    """Return the test-set accuracy (in percent) of a saved checkpoint.

    Loads the weights stored at ``./models/p=<p>_t=<threshold>`` into a fresh
    ``Net``, runs it over the module-level ``testloader`` and counts how many
    predictions match their labels.

    Args:
        p: Value used here only to build the checkpoint file name.
        threshold: Value used here only to build the checkpoint file name.

    Returns:
        Percentage of correctly classified samples, between 0 and 100.
        Returns 0.0 when ``testloader`` yields no samples.
    """
    net = Net()
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    net.load_state_dict(torch.load(f'./models/p={p}_t={threshold}'))
    net.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            # NOTE(review): `is_trained` is specific to this project's Net;
            # confirm its meaning against Net.forward.
            outputs = net(images, is_trained=True)
            _, predicted = torch.max(outputs, 1)
            # Count every sample in the batch rather than only the first one,
            # so the result is correct for any DataLoader batch size.
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # Avoid dividing by zero when the loader is empty.
    if total == 0:
        return 0.0
    return correct / total * 100
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement