Advertisement
lamiastella

Per-class accuracy

Nov 13th, 2018
358
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.62 KB | None | 0 0
  # Per-class accuracy and confusion matrix for the 'test' split.
import torch

nb_classes = 9


def update_confusion_matrix(matrix, targets, preds):
    """Add one batch of (true, predicted) class pairs to ``matrix`` in place.

    Rows index the true class and columns the predicted class, so
    ``matrix[t, p]`` counts samples of class ``t`` predicted as ``p``.
    Class ids must be integers in ``[0, matrix.shape[0])``; out-of-range ids
    make the reshape below fail with a RuntimeError.

    Args:
        matrix: square ``(C, C)`` tensor of counts, modified in place.
        targets: tensor of true class ids (flattened before use).
        preds: tensor of predicted class ids, same number of elements.

    Returns:
        ``matrix`` itself, after the update.
    """
    n = matrix.shape[0]
    # Move ids to the CPU so GPU label tensors can update a CPU matrix.
    true_ids = targets.reshape(-1).long().cpu()
    pred_ids = preds.reshape(-1).long().cpu()
    # Encode each (true, pred) pair as one flat index; a single bincount then
    # counts the whole batch, whatever its size.
    counts = torch.bincount(true_ids * n + pred_ids, minlength=n * n)
    matrix.add_(counts.reshape(n, n).to(device=matrix.device, dtype=matrix.dtype))
    return matrix


def per_class_accuracy(matrix):
    """Return the accuracy (recall) of each true class as a list of floats.

    Entry ``k`` is ``matrix[k, k] / matrix[k, :].sum()``: the fraction of
    samples whose true class is ``k`` that were predicted as ``k``. A class
    with no samples yields ``nan`` instead of raising ``ZeroDivisionError``.
    """
    counts = matrix.double()
    return (counts.diagonal() / counts.sum(dim=1)).tolist()


def evaluate_model(model, dataloader, device, nb_classes=9, verbose=False):
    """Run ``model`` over ``dataloader`` and tally its predictions.

    Args:
        model: ``torch.nn.Module``; the predicted class is the argmax of its
            output along dim 1.
        dataloader: iterable of ``(inputs, classes)`` batches. Any batch size
            works, including a shorter final batch.
        device: device the inputs are moved to before the forward pass.
        nb_classes: number of classes; labels are ids ``0 .. nb_classes - 1``.
        verbose: if true, print each batch's true labels.

    Returns:
        ``(confusion_matrix, predicted_labels)``. ``confusion_matrix`` is a
        float ``(nb_classes, nb_classes)`` CPU tensor of counts (rows = true
        class, columns = predicted class). ``predicted_labels`` holds one list
        of predicted ids per batch.
    """
    matrix = torch.zeros(nb_classes, nb_classes)
    predicted_labels = []
    # Evaluation mode: dropout disabled, BatchNorm uses running statistics.
    model.eval()
    with torch.no_grad():
        for inputs, classes in dataloader:
            # Exactly one forward pass per batch.
            outputs = model(inputs.to(device))
            _, preds = torch.max(outputs, 1)
            if verbose:
                print('classes list: ', classes.tolist())
            predicted_labels.append(preds.cpu().tolist())
            update_confusion_matrix(matrix, classes, preds)
    return matrix, predicted_labels


# Guarded so the helpers above can be imported (e.g. for testing) without the
# notebook globals. In Jupyter/IPython and when run as a script, __name__ is
# '__main__', so this part still runs there.
if __name__ == '__main__':
    # NOTE(review): assumes `model_ft`, `dataloaders` and `device` are defined
    # by earlier training cells; they are not defined in this snippet.
    confusion_matrix, predicted_labels = evaluate_model(
        model_ft, dataloaders['test'], device, nb_classes)
    print(confusion_matrix)

    # Labels are 0..nb_classes-1 (that is how the confusion matrix is indexed),
    # so every class, including 0, gets an accuracy; absent classes print nan.
    per_class_acc = per_class_accuracy(confusion_matrix)
    for label, acc in enumerate(per_class_acc):
        print(label, acc)
  105.  
  106.  
  107.  
  108. -----------------------------------------------------------------------------------
  109.  
  110. classes list:  [2, 1, 5, 3]
  111. classes list:  [1, 4, 6, 2]
  112. classes list:  [3, 7, 6, 4]
  113. classes list:  [5, 1, 8, 5]
  114. classes list:  [2, 2, 2, 6]
  115. classes list:  [4, 4, 1, 2]
  116. classes list:  [1, 8, 6, 8]
  117. classes list:  [2, 3, 2, 2]
  118. classes list:  [2, 6, 4, 0]
  119. classes list:  [1, 1, 1, 0]
  120. classes list:  [1, 8, 6, 1]
  121. classes list:  [1, 2, 1, 2]
  122. classes list:  [1, 2, 6, 4]
  123. classes list:  [5, 8, 8, 3]
  124. classes list:  [4, 2, 4, 2]
  125. classes list:  [2, 1, 2, 2]
  126. classes list:  [6, 2, 4, 1]
  127. classes list:  [2, 6, 1, 1]
  128. classes list:  [3, 2, 2, 2]
  129. classes list:  [6, 2, 2, 2]
  130. classes list:  [8, 3, 7, 2]
  131. classes list:  [4, 2, 6, 4]
  132. classes list:  [5, 1, 6, 2]
  133. classes list:  [2, 2, 1, 7]
  134. classes list:  [4, 4, 1, 5]
  135. classes list:  [6, 2, 2, 4]
  136. classes list:  [2, 5, 3, 1]
  137. classes list:  [0, 1, 5, 6]
  138. classes list:  [2, 6, 6, 2]
  139. classes list:  [6, 4, 2, 6]
  140. classes list:  [4, 2, 6, 2]
  141. classes list:  [6, 8, 2, 2]
  142. classes list:  [5, 6, 8, 6]
  143. classes list:  [1, 8, 0, 6]
  144. classes list:  [2, 1, 7, 6]
  145. tensor([[ 0.,  2.,  1.,  0.,  0.,  1.,  0.,  0.,  0.],
  146.         [ 0.,  8., 10.,  0.,  0.,  0.,  4.,  0.,  2.],
  147.         [ 0.,  2., 29.,  0.,  3.,  0.,  7.,  0.,  1.],
  148.         [ 0.,  0.,  0.,  1.,  4.,  0.,  2.,  0.,  0.],
  149.         [ 0.,  3.,  4.,  2.,  0.,  0.,  7.,  0.,  0.],
  150.         [ 0.,  1.,  3.,  0.,  0.,  4.,  1.,  0.,  0.],
  151.         [ 0.,  3.,  3.,  1.,  0.,  0., 13.,  0.,  4.],
  152.         [ 0.,  1.,  0.,  1.,  1.,  0.,  1.,  0.,  0.],
  153.         [ 0.,  2.,  4.,  0.,  0.,  0.,  2.,  0.,  2.]])
  154. 1 0.3333333333333333
  155. 2 0.6904761904761905
  156. 3 0.14285714285714285
  157. 4 0.0
  158. 5 0.4444444444444444
  159. 6 0.5416666666666666
  160. 7 0.0
  161. 8 0.2
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement