Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- mport torchvision.datasets as datasets
- class MonaDataset(datasets.folder.ImageFolder):
- def __init__(self, root, transform=None, target_transform=None,
- loader=datasets.folder.default_loader):
- super(MonaDataset, self).__init__(root, transform, target_transform, loader)
- def __getitem__(self, index):
- path, target = self.samples[index]
- sample = self.loader(path)
- if self.transform is not None:
- sample = self.transform(sample)
- if self.target_transform is not None:
- target = self.target_transform(target)
- return sample, target, path
- dataset = MonaDataset('10folds/10fold_9')
- print(len(dataset))
- x, y, im_path = dataset[0]
- print("x is: {}, y is: {}, im_path is: {}".format(x, y, im_path))
- image_datasets = {x: MonaDataset(os.path.join(data_dir, x),
- data_transforms[x])
- for x in ['train', 'test']}
- dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
- shuffle=True, num_workers=4)
- for x in ['train', 'test']}
- dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
- class_names = image_datasets['train'].classes
- nb_classes = 9
- import torch.nn.functional as F
- confusion_matrix = torch.zeros(nb_classes, nb_classes)
- _classes = []
- _preds = []
- predicted_labels = []
- loocv_probs = []
- with torch.no_grad():
- for i, (inputs, classes, im_path) in enumerate(dataloaders['test']):
- inputs = inputs.to(device)
- tmp_labels = model_ft(inputs)
- classes = classes.to(device)
- classes_list = classes.cpu().detach().numpy().tolist()
- _classes[:]=[i+1 for i in classes_list]
- outputs = model_ft(inputs)
- gpu_tensor_probs = F.softmax(outputs, 1)
- cpu_numpy_probs = gpu_tensor_probs.data.cpu().numpy()
- loocv_probs.append(cpu_numpy_probs.tolist())
- _, preds = torch.max(outputs, 1)
- preds_list = preds.cpu().detach().numpy().tolist()
- _preds[:]=[i+1 for i in preds_list]
- predicted_labels.append(preds.cpu().detach().numpy().tolist())
- for t, p in zip(classes.view(-1), preds.view(-1)):
- confusion_matrix[t.long(), p.long()] += 1
- print(confusion_matrix)
- print(confusion_matrix.diag()/confusion_matrix.sum(1))
- #print('Class probabilities:', loocv_probs)
- print(len(loocv_probs))
- for i in range(len(loocv_probs)): #21
- for j in range(len(loocv_probs[0])): #4
- print(*[f"{element:.2f}" for element in loocv_probs[i][j]], sep=', ', end='\n')
- for i in range(9):
- print("class {:d} --> accuracy: {:.2f}, correct predictions: {:d}, all: {:d}".format(i+1, (confusion_matrix.diag()/confusion_matrix.sum(1))[i]*100, int(confusion_matrix[i][i].numpy()), int(confusion_matrix.sum(dim=1)[i].numpy())))
- ---------------------------------------------------------------------------------
- tensor([[ 0., 0., 1., 0., 0., 1., 0., 0., 0.],
- [ 0., 5., 4., 0., 1., 0., 1., 0., 2.],
- [ 0., 4., 15., 0., 2., 0., 4., 0., 1.],
- [ 0., 0., 0., 0., 0., 1., 3., 0., 0.],
- [ 0., 3., 3., 0., 3., 1., 0., 0., 0.],
- [ 0., 2., 0., 0., 0., 2., 1., 0., 0.],
- [ 0., 1., 3., 0., 2., 0., 9., 0., 0.],
- [ 0., 0., 0., 0., 0., 0., 1., 0., 1.],
- [ 0., 1., 2., 0., 0., 0., 1., 0., 2.]])
- tensor([0.0000, 0.3846, 0.5769, 0.0000, 0.3000, 0.4000, 0.6000, 0.0000, 0.3333])
- 21
- 0.05, 0.13, 0.07, 0.00, 0.08, 0.00, 0.25, 0.02, 0.39
- 0.04, 0.09, 0.52, 0.03, 0.15, 0.01, 0.07, 0.05, 0.04
- 0.01, 0.08, 0.49, 0.05, 0.06, 0.12, 0.14, 0.04, 0.01
- 0.02, 0.25, 0.19, 0.11, 0.18, 0.14, 0.08, 0.02, 0.01
- 0.02, 0.10, 0.14, 0.07, 0.32, 0.01, 0.27, 0.05, 0.02
- 0.00, 0.01, 0.96, 0.00, 0.01, 0.00, 0.01, 0.00, 0.00
- 0.01, 0.40, 0.02, 0.01, 0.07, 0.00, 0.06, 0.03, 0.40
- 0.02, 0.10, 0.08, 0.05, 0.05, 0.47, 0.18, 0.01, 0.02
- 0.04, 0.02, 0.11, 0.02, 0.18, 0.01, 0.17, 0.04, 0.41
- 0.01, 0.22, 0.68, 0.01, 0.03, 0.01, 0.01, 0.01, 0.01
- 0.03, 0.35, 0.13, 0.15, 0.07, 0.01, 0.16, 0.06, 0.03
- 0.02, 0.07, 0.12, 0.02, 0.17, 0.19, 0.37, 0.02, 0.02
- 0.03, 0.02, 0.89, 0.00, 0.01, 0.01, 0.01, 0.01, 0.02
- 0.03, 0.34, 0.09, 0.05, 0.07, 0.23, 0.03, 0.05, 0.12
- 0.01, 0.01, 0.06, 0.05, 0.09, 0.01, 0.73, 0.03, 0.02
- 0.00, 0.67, 0.04, 0.01, 0.18, 0.00, 0.09, 0.01, 0.01
- 0.03, 0.23, 0.05, 0.04, 0.16, 0.02, 0.31, 0.04, 0.12
- 0.01, 0.11, 0.74, 0.02, 0.05, 0.01, 0.02, 0.01, 0.02
- 0.01, 0.09, 0.20, 0.06, 0.08, 0.47, 0.07, 0.01, 0.03
- 0.06, 0.08, 0.21, 0.02, 0.13, 0.00, 0.36, 0.10, 0.04
- 0.02, 0.03, 0.10, 0.01, 0.04, 0.00, 0.79, 0.01, 0.01
- 0.00, 0.09, 0.34, 0.18, 0.11, 0.22, 0.03, 0.02, 0.00
- 0.03, 0.26, 0.26, 0.06, 0.19, 0.04, 0.08, 0.05, 0.03
- 0.02, 0.09, 0.05, 0.00, 0.03, 0.01, 0.03, 0.02, 0.75
- 0.10, 0.07, 0.15, 0.01, 0.15, 0.00, 0.14, 0.02, 0.37
- 0.04, 0.13, 0.65, 0.02, 0.04, 0.05, 0.04, 0.01, 0.03
- 0.01, 0.20, 0.26, 0.05, 0.11, 0.19, 0.10, 0.06, 0.02
- 0.01, 0.10, 0.07, 0.18, 0.20, 0.03, 0.34, 0.06, 0.01
- 0.02, 0.03, 0.88, 0.01, 0.02, 0.00, 0.02, 0.01, 0.01
- 0.11, 0.30, 0.16, 0.05, 0.12, 0.08, 0.07, 0.06, 0.05
- 0.01, 0.06, 0.26, 0.02, 0.28, 0.04, 0.25, 0.03, 0.05
- 0.01, 0.34, 0.04, 0.06, 0.09, 0.04, 0.35, 0.03, 0.05
- 0.02, 0.66, 0.17, 0.01, 0.01, 0.00, 0.07, 0.02, 0.03
- 0.01, 0.11, 0.07, 0.05, 0.20, 0.45, 0.07, 0.02, 0.03
- 0.02, 0.03, 0.04, 0.09, 0.48, 0.03, 0.22, 0.05, 0.05
- 0.01, 0.02, 0.87, 0.00, 0.02, 0.01, 0.04, 0.00, 0.02
- 0.01, 0.01, 0.05, 0.02, 0.04, 0.00, 0.83, 0.01, 0.02
- 0.01, 0.71, 0.05, 0.02, 0.05, 0.05, 0.09, 0.01, 0.01
- 0.01, 0.01, 0.95, 0.00, 0.01, 0.00, 0.00, 0.01, 0.01
- 0.02, 0.10, 0.05, 0.15, 0.30, 0.09, 0.16, 0.07, 0.05
- 0.01, 0.01, 0.25, 0.11, 0.18, 0.01, 0.35, 0.07, 0.01
- 0.01, 0.49, 0.18, 0.01, 0.05, 0.13, 0.06, 0.00, 0.07
- 0.08, 0.09, 0.10, 0.01, 0.11, 0.00, 0.10, 0.04, 0.46
- 0.03, 0.27, 0.38, 0.05, 0.09, 0.10, 0.05, 0.04, 0.01
- 0.01, 0.22, 0.06, 0.04, 0.14, 0.14, 0.29, 0.05, 0.04
- 0.02, 0.07, 0.78, 0.00, 0.01, 0.00, 0.04, 0.00, 0.07
- 0.02, 0.13, 0.21, 0.04, 0.13, 0.24, 0.08, 0.04, 0.10
- 0.05, 0.08, 0.20, 0.09, 0.35, 0.01, 0.16, 0.06, 0.01
- 0.05, 0.06, 0.34, 0.18, 0.15, 0.02, 0.11, 0.08, 0.01
- 0.01, 0.31, 0.17, 0.04, 0.15, 0.02, 0.27, 0.02, 0.02
- 0.01, 0.14, 0.63, 0.02, 0.07, 0.02, 0.06, 0.02, 0.03
- 0.05, 0.07, 0.06, 0.01, 0.08, 0.03, 0.10, 0.02, 0.58
- 0.01, 0.01, 0.05, 0.12, 0.19, 0.05, 0.48, 0.05, 0.03
- 0.01, 0.39, 0.26, 0.04, 0.11, 0.05, 0.12, 0.02, 0.01
- 0.01, 0.05, 0.79, 0.01, 0.02, 0.01, 0.01, 0.01, 0.10
- 0.03, 0.36, 0.11, 0.01, 0.10, 0.01, 0.26, 0.05, 0.06
- 0.00, 0.07, 0.89, 0.00, 0.01, 0.01, 0.00, 0.00, 0.01
- 0.01, 0.04, 0.21, 0.12, 0.24, 0.12, 0.21, 0.04, 0.03
- 0.03, 0.31, 0.01, 0.01, 0.10, 0.01, 0.24, 0.02, 0.26
- 0.05, 0.07, 0.39, 0.08, 0.12, 0.01, 0.22, 0.05, 0.01
- 0.04, 0.16, 0.02, 0.01, 0.21, 0.00, 0.31, 0.03, 0.21
- 0.01, 0.04, 0.11, 0.13, 0.07, 0.59, 0.02, 0.02, 0.00
- 0.01, 0.14, 0.20, 0.03, 0.12, 0.03, 0.37, 0.03, 0.07
- 0.02, 0.03, 0.91, 0.00, 0.01, 0.00, 0.01, 0.01, 0.01
- 0.05, 0.39, 0.24, 0.04, 0.08, 0.09, 0.04, 0.01, 0.05
- 0.02, 0.12, 0.39, 0.02, 0.09, 0.03, 0.27, 0.03, 0.02
- 0.02, 0.06, 0.78, 0.00, 0.03, 0.00, 0.03, 0.02, 0.05
- 0.01, 0.06, 0.03, 0.14, 0.30, 0.05, 0.35, 0.04, 0.03
- 0.03, 0.15, 0.08, 0.04, 0.08, 0.03, 0.38, 0.12, 0.08
- 0.00, 0.48, 0.09, 0.01, 0.22, 0.02, 0.14, 0.01, 0.03
- 0.07, 0.02, 0.43, 0.14, 0.15, 0.03, 0.10, 0.04, 0.03
- 0.04, 0.11, 0.71, 0.01, 0.04, 0.02, 0.03, 0.01, 0.03
- 0.01, 0.03, 0.87, 0.00, 0.01, 0.00, 0.00, 0.00, 0.07
- 0.01, 0.11, 0.06, 0.08, 0.14, 0.07, 0.46, 0.04, 0.04
- 0.01, 0.07, 0.31, 0.21, 0.23, 0.07, 0.05, 0.05, 0.01
- 0.09, 0.19, 0.04, 0.04, 0.07, 0.03, 0.52, 0.02, 0.02
- 0.01, 0.20, 0.13, 0.02, 0.07, 0.05, 0.40, 0.02, 0.11
- 0.03, 0.25, 0.63, 0.01, 0.02, 0.04, 0.01, 0.01, 0.01
- 0.09, 0.07, 0.11, 0.10, 0.24, 0.08, 0.13, 0.05, 0.13
- 0.01, 0.05, 0.15, 0.05, 0.29, 0.00, 0.38, 0.06, 0.01
- 0.03, 0.09, 0.82, 0.00, 0.01, 0.00, 0.01, 0.01, 0.03
- 0.01, 0.16, 0.12, 0.05, 0.40, 0.02, 0.19, 0.01, 0.04
- 0.02, 0.04, 0.05, 0.08, 0.23, 0.12, 0.35, 0.09, 0.03
- ---------------------------------------------------------------------------
- IndexError Traceback (most recent call last)
- <ipython-input-20-cb5baf0f1620> in <module>()
- 43 for i in range(len(loocv_probs)): #21
- 44 for j in range(len(loocv_probs[0])): #4
- ---> 45 print(*[f"{element:.2f}" for element in loocv_probs[i][j]], sep=', ', end='\n')
- 46
- 47
- IndexError: list index out of range
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement