Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Evaluate ``model_ft`` on ``dataloaders['test']``: build a confusion matrix
# (rows = true class, columns = predicted class) and report per-class accuracy
# (i.e. recall: correct predictions / samples of that class).
#
# NOTE(review): ``dataloaders``, ``device`` and ``model_ft`` are defined
# elsewhere; this snippet assumes the loader yields ``(inputs, labels)`` with
# 0-based integer labels in ``range(nb_classes)`` — confirm against the caller.
# NOTE(review): test-time evaluation normally wants ``model_ft.eval()`` so that
# dropout / batch-norm run in inference mode; it is not called here because the
# snippet does not show whether the caller already did — confirm.


def _accumulate_confusion(matrix, targets, predictions):
    """Add one batch of (true label, predicted label) pairs into *matrix* in place.

    Equivalent to ``matrix[t, p] += 1`` for every pair, but vectorised with
    ``torch.bincount`` instead of a per-element Python loop. Both label tensors
    are moved to the CPU because *matrix* lives there.

    Args:
        matrix: square ``(C, C)`` tensor of counts; row = true class,
            column = predicted class.
        targets: tensor of true class indices (any shape, flattened).
        predictions: tensor of predicted class indices, same number of elements.
    """
    n = matrix.shape[0]
    t = targets.reshape(-1).long().cpu()
    p = predictions.reshape(-1).long().cpu()
    # Encode each (t, p) pair as a single flat index so one bincount covers
    # the whole batch; minlength keeps the result at exactly n*n bins.
    counts = torch.bincount(t * n + p, minlength=n * n)
    matrix += counts.view(n, n).to(matrix.dtype)


def _per_class_accuracy(matrix):
    """Return a list with correct/total for each true class (row) of *matrix*.

    Classes that have no samples yield ``float('nan')`` rather than raising
    ZeroDivisionError. The division is done on Python ints so the printed
    values match plain ``int / int`` arithmetic.
    """
    accuracies = []
    for k in range(matrix.shape[0]):
        total = int(matrix[k].sum().item())
        correct = int(matrix[k, k].item())
        accuracies.append(correct / total if total else float("nan"))
    return accuracies


nb_classes = 9
confusion_matrix = torch.zeros(nb_classes, nb_classes)
predicted_labels = []  # one list of 0-based predictions per test batch

with torch.no_grad():
    for inputs, classes in dataloaders['test']:
        inputs = inputs.to(device)
        outputs = model_ft(inputs)  # single forward pass per batch
        _, preds = torch.max(outputs, 1)
        preds = preds.cpu()
        predicted_labels.append(preds.tolist())
        # Uses every sample in the batch, whatever the batch size.
        _accumulate_confusion(confusion_matrix, classes, preds)

print(confusion_matrix)

# Row k of the matrix is class k (0-based); it is reported as label k + 1,
# matching the original '1'..'9' output.
class_accuracies = _per_class_accuracy(confusion_matrix)
for label, acc in enumerate(class_accuracies, start=1):
    print(label, acc)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement