Advertisement
lamiastella

loocv_probs

Nov 28th, 2018
305
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.44 KB | None | 0 0
import os

import torch
import torchvision.datasets as datasets
  2. class MonaDataset(datasets.folder.ImageFolder):
  3.     def __init__(self, root, transform=None, target_transform=None,
  4.                  loader=datasets.folder.default_loader):
  5.         super(MonaDataset, self).__init__(root, transform, target_transform, loader)
  6.  
  7.     def __getitem__(self, index):
  8.         path, target = self.samples[index]
  9.         sample = self.loader(path)
  10.         if self.transform is not None:
  11.             sample = self.transform(sample)
  12.         if self.target_transform is not None:
  13.             target = self.target_transform(target)
  14.         return sample, target, path
  15.  
  16. dataset = MonaDataset('10folds/10fold_9')
  17. print(len(dataset))
  18. x, y, im_path = dataset[0]
  19.  
  20.  
  21. print("x is: {}, y is: {}, im_path is: {}".format(x, y, im_path))
  22.  
  23. image_datasets = {x: MonaDataset(os.path.join(data_dir, x),
  24.                                           data_transforms[x])
  25.                   for x in ['train', 'test']}
  26. dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
  27.                                              shuffle=True, num_workers=4)
  28.               for x in ['train', 'test']}
  29. dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
  30.  
  31.  
  32. class_names = image_datasets['train'].classes
  33.  
  34.  
  35.  
  36. nb_classes = 9
  37.  
  38. import torch.nn.functional as F
  39.  
  40. confusion_matrix = torch.zeros(nb_classes, nb_classes)
  41.  
  42. _classes = []
  43. _preds = []
  44. predicted_labels = []
  45. loocv_probs = []
  46.  
  47. with torch.no_grad():
  48.     for i, (inputs, classes, im_path) in enumerate(dataloaders['test']):
  49.  
  50.        
  51.      
  52.         inputs = inputs.to(device)
  53.         tmp_labels = model_ft(inputs)
  54.        
  55.         classes = classes.to(device)
  56.         classes_list = classes.cpu().detach().numpy().tolist()
  57.         _classes[:]=[i+1 for i in classes_list]
  58.         outputs = model_ft(inputs)
  59.        
  60.         gpu_tensor_probs = F.softmax(outputs, 1)
  61.         cpu_numpy_probs = gpu_tensor_probs.data.cpu().numpy()
  62.         loocv_probs.append(cpu_numpy_probs.tolist())
  63.    
  64.         _, preds = torch.max(outputs, 1)
  65.         preds_list = preds.cpu().detach().numpy().tolist()
  66.         _preds[:]=[i+1 for i in preds_list]
  67.          
  68.         predicted_labels.append(preds.cpu().detach().numpy().tolist())
  69.         for t, p in zip(classes.view(-1), preds.view(-1)):
  70.                 confusion_matrix[t.long(), p.long()] += 1
  71.                
  72. print(confusion_matrix)
  73. print(confusion_matrix.diag()/confusion_matrix.sum(1))
  74. #print('Class probabilities:', loocv_probs)
  75. print(len(loocv_probs))
  76.  
  77.  
  78. for i in range(len(loocv_probs)): #21
  79.     for j in range(len(loocv_probs[0])): #4
  80.         print(*[f"{element:.2f}" for element in loocv_probs[i][j]], sep=', ', end='\n')
  81.  
  82.  
  83. for i in range(9):
  84.     print("class {:d} --> accuracy: {:.2f}, correct predictions: {:d}, all: {:d}".format(i+1, (confusion_matrix.diag()/confusion_matrix.sum(1))[i]*100, int(confusion_matrix[i][i].numpy()), int(confusion_matrix.sum(dim=1)[i].numpy())))
  85.  
  86.  
  87.  
  88.  
  89. ---------------------------------------------------------------------------------
  90.  
  91.  
  92.  
  93. tensor([[ 0.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.],
  94.         [ 0.,  5.,  4.,  0.,  1.,  0.,  1.,  0.,  2.],
  95.         [ 0.,  4., 15.,  0.,  2.,  0.,  4.,  0.,  1.],
  96.         [ 0.,  0.,  0.,  0.,  0.,  1.,  3.,  0.,  0.],
  97.         [ 0.,  3.,  3.,  0.,  3.,  1.,  0.,  0.,  0.],
  98.         [ 0.,  2.,  0.,  0.,  0.,  2.,  1.,  0.,  0.],
  99.         [ 0.,  1.,  3.,  0.,  2.,  0.,  9.,  0.,  0.],
  100.         [ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  1.],
  101.         [ 0.,  1.,  2.,  0.,  0.,  0.,  1.,  0.,  2.]])
  102. tensor([0.0000, 0.3846, 0.5769, 0.0000, 0.3000, 0.4000, 0.6000, 0.0000, 0.3333])
  103. 21
  104. 0.05, 0.13, 0.07, 0.00, 0.08, 0.00, 0.25, 0.02, 0.39
  105. 0.04, 0.09, 0.52, 0.03, 0.15, 0.01, 0.07, 0.05, 0.04
  106. 0.01, 0.08, 0.49, 0.05, 0.06, 0.12, 0.14, 0.04, 0.01
  107. 0.02, 0.25, 0.19, 0.11, 0.18, 0.14, 0.08, 0.02, 0.01
  108. 0.02, 0.10, 0.14, 0.07, 0.32, 0.01, 0.27, 0.05, 0.02
  109. 0.00, 0.01, 0.96, 0.00, 0.01, 0.00, 0.01, 0.00, 0.00
  110. 0.01, 0.40, 0.02, 0.01, 0.07, 0.00, 0.06, 0.03, 0.40
  111. 0.02, 0.10, 0.08, 0.05, 0.05, 0.47, 0.18, 0.01, 0.02
  112. 0.04, 0.02, 0.11, 0.02, 0.18, 0.01, 0.17, 0.04, 0.41
  113. 0.01, 0.22, 0.68, 0.01, 0.03, 0.01, 0.01, 0.01, 0.01
  114. 0.03, 0.35, 0.13, 0.15, 0.07, 0.01, 0.16, 0.06, 0.03
  115. 0.02, 0.07, 0.12, 0.02, 0.17, 0.19, 0.37, 0.02, 0.02
  116. 0.03, 0.02, 0.89, 0.00, 0.01, 0.01, 0.01, 0.01, 0.02
  117. 0.03, 0.34, 0.09, 0.05, 0.07, 0.23, 0.03, 0.05, 0.12
  118. 0.01, 0.01, 0.06, 0.05, 0.09, 0.01, 0.73, 0.03, 0.02
  119. 0.00, 0.67, 0.04, 0.01, 0.18, 0.00, 0.09, 0.01, 0.01
  120. 0.03, 0.23, 0.05, 0.04, 0.16, 0.02, 0.31, 0.04, 0.12
  121. 0.01, 0.11, 0.74, 0.02, 0.05, 0.01, 0.02, 0.01, 0.02
  122. 0.01, 0.09, 0.20, 0.06, 0.08, 0.47, 0.07, 0.01, 0.03
  123. 0.06, 0.08, 0.21, 0.02, 0.13, 0.00, 0.36, 0.10, 0.04
  124. 0.02, 0.03, 0.10, 0.01, 0.04, 0.00, 0.79, 0.01, 0.01
  125. 0.00, 0.09, 0.34, 0.18, 0.11, 0.22, 0.03, 0.02, 0.00
  126. 0.03, 0.26, 0.26, 0.06, 0.19, 0.04, 0.08, 0.05, 0.03
  127. 0.02, 0.09, 0.05, 0.00, 0.03, 0.01, 0.03, 0.02, 0.75
  128. 0.10, 0.07, 0.15, 0.01, 0.15, 0.00, 0.14, 0.02, 0.37
  129. 0.04, 0.13, 0.65, 0.02, 0.04, 0.05, 0.04, 0.01, 0.03
  130. 0.01, 0.20, 0.26, 0.05, 0.11, 0.19, 0.10, 0.06, 0.02
  131. 0.01, 0.10, 0.07, 0.18, 0.20, 0.03, 0.34, 0.06, 0.01
  132. 0.02, 0.03, 0.88, 0.01, 0.02, 0.00, 0.02, 0.01, 0.01
  133. 0.11, 0.30, 0.16, 0.05, 0.12, 0.08, 0.07, 0.06, 0.05
  134. 0.01, 0.06, 0.26, 0.02, 0.28, 0.04, 0.25, 0.03, 0.05
  135. 0.01, 0.34, 0.04, 0.06, 0.09, 0.04, 0.35, 0.03, 0.05
  136. 0.02, 0.66, 0.17, 0.01, 0.01, 0.00, 0.07, 0.02, 0.03
  137. 0.01, 0.11, 0.07, 0.05, 0.20, 0.45, 0.07, 0.02, 0.03
  138. 0.02, 0.03, 0.04, 0.09, 0.48, 0.03, 0.22, 0.05, 0.05
  139. 0.01, 0.02, 0.87, 0.00, 0.02, 0.01, 0.04, 0.00, 0.02
  140. 0.01, 0.01, 0.05, 0.02, 0.04, 0.00, 0.83, 0.01, 0.02
  141. 0.01, 0.71, 0.05, 0.02, 0.05, 0.05, 0.09, 0.01, 0.01
  142. 0.01, 0.01, 0.95, 0.00, 0.01, 0.00, 0.00, 0.01, 0.01
  143. 0.02, 0.10, 0.05, 0.15, 0.30, 0.09, 0.16, 0.07, 0.05
  144. 0.01, 0.01, 0.25, 0.11, 0.18, 0.01, 0.35, 0.07, 0.01
  145. 0.01, 0.49, 0.18, 0.01, 0.05, 0.13, 0.06, 0.00, 0.07
  146. 0.08, 0.09, 0.10, 0.01, 0.11, 0.00, 0.10, 0.04, 0.46
  147. 0.03, 0.27, 0.38, 0.05, 0.09, 0.10, 0.05, 0.04, 0.01
  148. 0.01, 0.22, 0.06, 0.04, 0.14, 0.14, 0.29, 0.05, 0.04
  149. 0.02, 0.07, 0.78, 0.00, 0.01, 0.00, 0.04, 0.00, 0.07
  150. 0.02, 0.13, 0.21, 0.04, 0.13, 0.24, 0.08, 0.04, 0.10
  151. 0.05, 0.08, 0.20, 0.09, 0.35, 0.01, 0.16, 0.06, 0.01
  152. 0.05, 0.06, 0.34, 0.18, 0.15, 0.02, 0.11, 0.08, 0.01
  153. 0.01, 0.31, 0.17, 0.04, 0.15, 0.02, 0.27, 0.02, 0.02
  154. 0.01, 0.14, 0.63, 0.02, 0.07, 0.02, 0.06, 0.02, 0.03
  155. 0.05, 0.07, 0.06, 0.01, 0.08, 0.03, 0.10, 0.02, 0.58
  156. 0.01, 0.01, 0.05, 0.12, 0.19, 0.05, 0.48, 0.05, 0.03
  157. 0.01, 0.39, 0.26, 0.04, 0.11, 0.05, 0.12, 0.02, 0.01
  158. 0.01, 0.05, 0.79, 0.01, 0.02, 0.01, 0.01, 0.01, 0.10
  159. 0.03, 0.36, 0.11, 0.01, 0.10, 0.01, 0.26, 0.05, 0.06
  160. 0.00, 0.07, 0.89, 0.00, 0.01, 0.01, 0.00, 0.00, 0.01
  161. 0.01, 0.04, 0.21, 0.12, 0.24, 0.12, 0.21, 0.04, 0.03
  162. 0.03, 0.31, 0.01, 0.01, 0.10, 0.01, 0.24, 0.02, 0.26
  163. 0.05, 0.07, 0.39, 0.08, 0.12, 0.01, 0.22, 0.05, 0.01
  164. 0.04, 0.16, 0.02, 0.01, 0.21, 0.00, 0.31, 0.03, 0.21
  165. 0.01, 0.04, 0.11, 0.13, 0.07, 0.59, 0.02, 0.02, 0.00
  166. 0.01, 0.14, 0.20, 0.03, 0.12, 0.03, 0.37, 0.03, 0.07
  167. 0.02, 0.03, 0.91, 0.00, 0.01, 0.00, 0.01, 0.01, 0.01
  168. 0.05, 0.39, 0.24, 0.04, 0.08, 0.09, 0.04, 0.01, 0.05
  169. 0.02, 0.12, 0.39, 0.02, 0.09, 0.03, 0.27, 0.03, 0.02
  170. 0.02, 0.06, 0.78, 0.00, 0.03, 0.00, 0.03, 0.02, 0.05
  171. 0.01, 0.06, 0.03, 0.14, 0.30, 0.05, 0.35, 0.04, 0.03
  172. 0.03, 0.15, 0.08, 0.04, 0.08, 0.03, 0.38, 0.12, 0.08
  173. 0.00, 0.48, 0.09, 0.01, 0.22, 0.02, 0.14, 0.01, 0.03
  174. 0.07, 0.02, 0.43, 0.14, 0.15, 0.03, 0.10, 0.04, 0.03
  175. 0.04, 0.11, 0.71, 0.01, 0.04, 0.02, 0.03, 0.01, 0.03
  176. 0.01, 0.03, 0.87, 0.00, 0.01, 0.00, 0.00, 0.00, 0.07
  177. 0.01, 0.11, 0.06, 0.08, 0.14, 0.07, 0.46, 0.04, 0.04
  178. 0.01, 0.07, 0.31, 0.21, 0.23, 0.07, 0.05, 0.05, 0.01
  179. 0.09, 0.19, 0.04, 0.04, 0.07, 0.03, 0.52, 0.02, 0.02
  180. 0.01, 0.20, 0.13, 0.02, 0.07, 0.05, 0.40, 0.02, 0.11
  181. 0.03, 0.25, 0.63, 0.01, 0.02, 0.04, 0.01, 0.01, 0.01
  182. 0.09, 0.07, 0.11, 0.10, 0.24, 0.08, 0.13, 0.05, 0.13
  183. 0.01, 0.05, 0.15, 0.05, 0.29, 0.00, 0.38, 0.06, 0.01
  184. 0.03, 0.09, 0.82, 0.00, 0.01, 0.00, 0.01, 0.01, 0.03
  185. 0.01, 0.16, 0.12, 0.05, 0.40, 0.02, 0.19, 0.01, 0.04
  186. 0.02, 0.04, 0.05, 0.08, 0.23, 0.12, 0.35, 0.09, 0.03
  187.  
  188. ---------------------------------------------------------------------------
  189. IndexError                                Traceback (most recent call last)
  190. <ipython-input-20-cb5baf0f1620> in <module>()
  191.      43 for i in range(len(loocv_probs)): #21
  192.      44     for j in range(len(loocv_probs[0])): #4
  193. ---> 45         print(*[f"{element:.2f}" for element in loocv_probs[i][j]], sep=', ', end='\n')
  194.      46
  195.      47
  196.  
  197. IndexError: list index out of range
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement