Advertisement
lamiastella

Untitled

Nov 13th, 2018
283
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.02 KB | None | 0 0
  1. nb_classes = 9
  2.  
  3.  
  4. confusion_matrix = torch.zeros(nb_classes, nb_classes)
  5. cat1_corr = 0
  6. cat1 = 0
  7. cat2_corr = 0
  8. cat2 = 0
  9. cat3_corr = 0
  10. cat3 = 0
  11. cat4_corr = 0
  12. cat4 = 0
  13. cat5_corr = 0
  14. cat5 = 0
  15. cat6_corr = 0
  16. cat6 = 0
  17. cat7_corr = 0
  18. cat7 = 0
  19. cat8_corr = 0
  20. cat8 = 0
  21. cat9_corr = 0
  22. cat9 = 0
  23. _classes = []
  24. predicted_labels = []
  25. with torch.no_grad():
  26.     for i, (inputs, classes) in enumerate(dataloaders['test']):
  27.        
  28.      
  29.         inputs = inputs.to(device)
  30.         tmp_labels = model_ft(inputs)
  31.        
  32.         classes = classes.to(device)
  33.         classes_list = classes.cpu().detach().numpy().tolist()
  34.         print('classes list: ', classes_list)
  35.         _classes[:]=[i+1 for i in classes_list]
  36.         print('_classes: ', _classes)
  37.         preds = preds.cpu().detach().numpy().tolist()
  38.         outputs = model_ft(inputs)
  39.         _, preds = torch.max(outputs, 1)
  40.         preds_list = preds.cpu().detach().numpy().tolist()
  41.         _preds[:]=[i+1 for i in preds_list]
  42.         for i in range(4):
  43.             if _classes[i] == 1:
  44.                 cat1 += 1
  45.                 if _classes[i] == _preds[i]:
  46.                     cat1_corr += 1
  47.             elif _classes[i] == 2:
  48.                 cat2 += 1
  49.                 if _classes[i] == _preds[i]:
  50.                     cat2_corr += 1
  51.             elif _classes[i] == 3:
  52.                 cat3 += 1
  53.                 if _classes[i] == _preds[i]:
  54.                     cat3_corr += 1
  55.             elif _classes[i] == 4:
  56.                 cat4 += 1
  57.                 if _classes[i] == _preds[i]:
  58.                     cat4_corr += 1
  59.             elif _classes[i] == 5:
  60.                 cat5 += 1
  61.                 if _classes[i] == _preds[i]:
  62.                     cat5_corr += 1
  63.             elif _classes[i] == 6:
  64.                 cat6 += 1
  65.                 if _classes[i] == _preds[i]:
  66.                     cat6_corr += 1
  67.             elif _classes[i] == 7:
  68.                 cat7 += 1
  69.                 if _classes[i] == _preds[i]:
  70.                     cat7_corr += 1
  71.             elif _classes[i] == 8:
  72.                 cat8 += 1
  73.                 if _classes[i] == _preds[i]:
  74.                     cat8_corr += 1
  75.             elif _classes[i] == 9:
  76.                 print('here')
  77.                 cat9 += 1
  78.                 if _classes[i] == _preds[i]:
  79.                     cat9_corr += 1
  80.                    
  81.                    
  82.         predicted_labels.append(preds.cpu().detach().numpy().tolist())
  83.         for t, p in zip(classes.view(-1), preds.view(-1)):
  84.                 confusion_matrix[t.long(), p.long()] += 1
  85.                
  86. print(confusion_matrix)
  87.  
  88.  
  89. cat1_acc = cat1_corr/cat1
  90. cat2_acc = cat2_corr/cat2
  91. cat3_acc = cat3_corr/cat3
  92. cat4_acc = cat4_corr/cat4
  93. cat5_acc = cat5_corr/cat5
  94. cat6_acc = cat6_corr/cat6
  95. cat7_acc = cat7_corr/cat7
  96. cat8_acc = cat8_corr/cat8
  97. cat9_acc = cat9_corr/cat9
  98.  
  99.  
  100. print('1', cat1_acc)
  101. print('2', cat2_acc)
  102. print('3', cat3_acc)
  103. print('4', cat4_acc)
  104. print('5', cat5_acc)
  105. print('6', cat6_acc)
  106. print('7', cat7_acc)
  107. print('8', cat8_acc)
  108. print('9', cat9_acc)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement