Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Per-class evaluation of ``model_ft`` on the ``dataloaders['test']`` split.
#
# Changes relative to the earlier version of this snippet:
#   * ``preds`` was read before it was assigned (NameError on the first batch).
#   * The model was run twice per batch and the first result was discarded.
#   * ``for i in range(4)`` assumed a batch size of 4 (IndexError on a short
#     final batch) and shadowed the outer loop index.
#   * Labels here run 0..nb_classes-1 (class 0 appears in the test labels and the
#     matrix is nb_classes x nb_classes), but the hand-written counters covered
#     1..9, so class 0 was never scored and the "class 9" branch was dead.
#   * A class with no test samples raised ZeroDivisionError.
# Per-class accuracy is now read off the confusion matrix (diagonal / row sum),
# which is exactly what the old counters computed for classes 1..8.

nb_classes = 9


def _accumulate_confusion(model, loader, device, num_classes):
    """Run ``model`` over ``loader`` and build a confusion matrix.

    Args:
        model: callable returning per-class scores for a batch of inputs.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device the inputs and labels are moved to before the forward pass.
        num_classes: number of classes (size of each matrix dimension).

    Returns:
        ``(confusion, predicted)``. ``confusion[t, p]`` counts samples whose true
        label is ``t`` and predicted label is ``p``. ``predicted`` holds one list
        of predicted class indices per batch.
    """
    confusion = torch.zeros(num_classes, num_classes)
    predicted = []
    # NOTE(review): model_ft.eval() is not visible in this snippet. Confirm it is
    # called beforehand, otherwise dropout/batch-norm stay in training mode.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)  # single forward pass per batch
            _, preds = torch.max(outputs, 1)

            # Plain Python ints: works for any batch size (including a short last
            # batch), and indexing the CPU matrix never mixes devices.
            true_list = labels.view(-1).cpu().tolist()
            pred_list = preds.view(-1).cpu().tolist()
            print('classes list: ', true_list)

            predicted.append(pred_list)
            for t, p in zip(true_list, pred_list):
                confusion[t, p] += 1
    return confusion, predicted


def _per_class_accuracy(confusion):
    """Return the accuracy (recall) of each class as a list indexed by label.

    Each entry is ``confusion[k, k] / confusion[k, :].sum()``. A class with no
    samples gets ``float('nan')`` instead of raising ZeroDivisionError.
    """
    correct = confusion.diag().tolist()
    totals = confusion.sum(dim=1).tolist()
    return [c / n if n else float('nan') for c, n in zip(correct, totals)]


confusion_matrix, predicted_labels = _accumulate_confusion(
    model_ft, dataloaders['test'], device, nb_classes
)
print(confusion_matrix)

class_accuracy = _per_class_accuracy(confusion_matrix)
for label, acc in enumerate(class_accuracy):
    print(label, acc)
- -----------------------------------------------------------------------------------
- classes list: [2, 1, 5, 3]
- classes list: [1, 4, 6, 2]
- classes list: [3, 7, 6, 4]
- classes list: [5, 1, 8, 5]
- classes list: [2, 2, 2, 6]
- classes list: [4, 4, 1, 2]
- classes list: [1, 8, 6, 8]
- classes list: [2, 3, 2, 2]
- classes list: [2, 6, 4, 0]
- classes list: [1, 1, 1, 0]
- classes list: [1, 8, 6, 1]
- classes list: [1, 2, 1, 2]
- classes list: [1, 2, 6, 4]
- classes list: [5, 8, 8, 3]
- classes list: [4, 2, 4, 2]
- classes list: [2, 1, 2, 2]
- classes list: [6, 2, 4, 1]
- classes list: [2, 6, 1, 1]
- classes list: [3, 2, 2, 2]
- classes list: [6, 2, 2, 2]
- classes list: [8, 3, 7, 2]
- classes list: [4, 2, 6, 4]
- classes list: [5, 1, 6, 2]
- classes list: [2, 2, 1, 7]
- classes list: [4, 4, 1, 5]
- classes list: [6, 2, 2, 4]
- classes list: [2, 5, 3, 1]
- classes list: [0, 1, 5, 6]
- classes list: [2, 6, 6, 2]
- classes list: [6, 4, 2, 6]
- classes list: [4, 2, 6, 2]
- classes list: [6, 8, 2, 2]
- classes list: [5, 6, 8, 6]
- classes list: [1, 8, 0, 6]
- classes list: [2, 1, 7, 6]
- tensor([[ 0., 2., 1., 0., 0., 1., 0., 0., 0.],
- [ 0., 8., 10., 0., 0., 0., 4., 0., 2.],
- [ 0., 2., 29., 0., 3., 0., 7., 0., 1.],
- [ 0., 0., 0., 1., 4., 0., 2., 0., 0.],
- [ 0., 3., 4., 2., 0., 0., 7., 0., 0.],
- [ 0., 1., 3., 0., 0., 4., 1., 0., 0.],
- [ 0., 3., 3., 1., 0., 0., 13., 0., 4.],
- [ 0., 1., 0., 1., 1., 0., 1., 0., 0.],
- [ 0., 2., 4., 0., 0., 0., 2., 0., 2.]])
- 1 0.3333333333333333
- 2 0.6904761904761905
- 3 0.14285714285714285
- 4 0.0
- 5 0.4444444444444444
- 6 0.5416666666666666
- 7 0.0
- 8 0.2
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement