Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- data_transforms = {
- 'train': transforms.Compose([
- transforms.RandomResizedCrop(224),
- transforms.RandomHorizontalFlip(),
- transforms.RandomRotation(20),
- transforms.ColorJitter(0.3, 0.3, 0.3),
- transforms.ToTensor(),
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- ]),
- # 'val': transforms.Compose([
- # transforms.Resize(256),
- # transforms.CenterCrop(224),
- # transforms.ToTensor(),
- # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- # ]),
- 'test': transforms.Compose([
- transforms.Resize(256),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- ]),
- }
- data_dir = "10folds/10fold_9"
- class MonaDataset(datasets.folder.ImageFolder):
- def __init__(self, root, transform=None, target_transform=None,
- loader=datasets.folder.default_loader):
- super(MonaDataset, self).__init__(root, transform, target_transform, loader)
- def __getitem__(self, index):
- path, target = self.samples[index]
- sample = self.loader(path)
- if self.transform is not None:
- sample = self.transform(sample)
- if self.target_transform is not None:
- target = self.target_transform(target)
- return sample, target, path
- image_datasets = {x: MonaDataset(os.path.join(data_dir, x),
- data_transforms[x])
- for x in ['train', 'test']}
- dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
- shuffle=True, num_workers=4)
- for x in ['train', 'test']}
- targets = []
- with torch.no_grad():
- for i, (inputs, classes, im_path) in enumerate(dataloaders['test']):
- targets.append(classes.cpu().detach().numpy().tolist())
- print(targets)
- flat_targets = [item for sublist in targets for item in sublist]
- flat_targets = torch.Tensor(flat_targets)
- print(flat_targets)
- class_sample_count = torch.tensor([(flat_targets == t).sum() for t in torch.unique(flat_targets, sorted=True)])
- print(class_sample_count)
- weight = 1. / class_sample_count.float()
- print(weight)
- samples_weight = torch.tensor([weight[t] for t in flat_targets])
- sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
- dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
- shuffle=True, num_workers=4, sampler = sampler)
- for x in ['train', 'test']}
- dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
- class_names = image_datasets['train'].classes
- ---------------------------------------------------------------------------
- [[6, 2, 2, 1], [6, 1, 6, 2], [4, 1, 2, 5], [8, 1, 6, 0], [6, 4, 1, 4], [6, 8, 6, 2], [2, 2, 2, 8], [2, 2, 5, 4], [4, 1, 2, 5], [1, 6, 5, 3], [2, 0, 1, 1], [5, 6, 2, 2], [6, 2, 8, 4], [1, 6, 4, 2], [3, 8, 1, 1], [2, 2, 6, 2], [3, 6, 2, 4], [7, 4, 6, 8], [2, 4, 3, 2], [2, 7, 2, 1], [2, 2, 6]]
- tensor([6., 2., 2., 1., 6., 1., 6., 2., 4., 1., 2., 5., 8., 1., 6., 0., 6., 4.,
- 1., 4., 6., 8., 6., 2., 2., 2., 2., 8., 2., 2., 5., 4., 4., 1., 2., 5.,
- 1., 6., 5., 3., 2., 0., 1., 1., 5., 6., 2., 2., 6., 2., 8., 4., 1., 6.,
- 4., 2., 3., 8., 1., 1., 2., 2., 6., 2., 3., 6., 2., 4., 7., 4., 6., 8.,
- 2., 4., 3., 2., 2., 7., 2., 1., 2., 2., 6.])
- tensor([ 2, 13, 26, 4, 10, 5, 15, 2, 6])
- tensor([0.5000, 0.0769, 0.0385, 0.2500, 0.1000, 0.2000, 0.0667, 0.5000, 0.1667])
- ---------------------------------------------------------------------------
- RuntimeError Traceback (most recent call last)
- <ipython-input-254-06eff51f165c> in <module>()
- 92 weight = 1. / class_sample_count.float()
- 93 print(weight)
- ---> 94 samples_weight = torch.tensor([weight[t] for t in flat_targets])
- 95 sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
- 96
- <ipython-input-254-06eff51f165c> in <listcomp>(.0)
- 92 weight = 1. / class_sample_count.float()
- 93 print(weight)
- ---> 94 samples_weight = torch.tensor([weight[t] for t in flat_targets])
- 95 sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
- 96
- RuntimeError: tensors used as indices must be long or byte tensors
Add Comment
Please, Sign In to add comment