lamiastella

weight sampler

Nov 29th, 2018
245
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.41 KB | None | 0 0
  1. data_transforms = {
  2.     'train': transforms.Compose([
  3.         transforms.RandomResizedCrop(224),
  4.         transforms.RandomHorizontalFlip(),
  5.         transforms.RandomRotation(20),
  6.         transforms.ColorJitter(0.3, 0.3, 0.3),
  7.         transforms.ToTensor(),
  8.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  9.     ]),
  10. #    'val': transforms.Compose([
  11. #        transforms.Resize(256),
  12. #        transforms.CenterCrop(224),
  13. #        transforms.ToTensor(),
  14. #        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  15. #    ]),
  16.    
  17.         'test': transforms.Compose([
  18.         transforms.Resize(256),
  19.         transforms.CenterCrop(224),
  20.         transforms.ToTensor(),
  21.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  22.     ]),
  23. }
  24.  
  25.  
  26. data_dir = "10folds/10fold_9"
  27.  
  28. class MonaDataset(datasets.folder.ImageFolder):
  29.     def __init__(self, root, transform=None, target_transform=None,
  30.                  loader=datasets.folder.default_loader):
  31.         super(MonaDataset, self).__init__(root, transform, target_transform, loader)
  32.  
  33.     def __getitem__(self, index):
  34.         path, target = self.samples[index]
  35.         sample = self.loader(path)
  36.         if self.transform is not None:
  37.             sample = self.transform(sample)
  38.         if self.target_transform is not None:
  39.             target = self.target_transform(target)
  40.         return sample, target, path
  41.  
  42.  
  43. image_datasets = {x: MonaDataset(os.path.join(data_dir, x),
  44.                                           data_transforms[x])
  45.                   for x in ['train', 'test']}
  46. dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
  47.                                              shuffle=True, num_workers=4)
  48.               for x in ['train', 'test']}
  49.  
  50.  
  51.  
  52. targets = []
  53. with torch.no_grad():
  54.     for i, (inputs, classes, im_path) in enumerate(dataloaders['test']):
  55.         targets.append(classes.cpu().detach().numpy().tolist())
  56.        
  57. print(targets)
  58. flat_targets = [item for sublist in targets for item in sublist]
  59. flat_targets = torch.Tensor(flat_targets)
  60. print(flat_targets)
  61. class_sample_count = torch.tensor([(flat_targets == t).sum() for t in torch.unique(flat_targets, sorted=True)])
  62. print(class_sample_count)
  63. weight = 1. / class_sample_count.float()
  64. print(weight)
  65. samples_weight = torch.tensor([weight[t] for t in flat_targets])
  66. sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
  67.  
  68.  
  69. dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
  70.                                              shuffle=True, num_workers=4, sampler = sampler)
  71.               for x in ['train', 'test']}
  72. dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
  73.  
  74.  
  75. class_names = image_datasets['train'].classes
  76.  
  77.  
  78. ---------------------------------------------------------------------------
  79.  
  80. [[6, 2, 2, 1], [6, 1, 6, 2], [4, 1, 2, 5], [8, 1, 6, 0], [6, 4, 1, 4], [6, 8, 6, 2], [2, 2, 2, 8], [2, 2, 5, 4], [4, 1, 2, 5], [1, 6, 5, 3], [2, 0, 1, 1], [5, 6, 2, 2], [6, 2, 8, 4], [1, 6, 4, 2], [3, 8, 1, 1], [2, 2, 6, 2], [3, 6, 2, 4], [7, 4, 6, 8], [2, 4, 3, 2], [2, 7, 2, 1], [2, 2, 6]]
  81. tensor([6., 2., 2., 1., 6., 1., 6., 2., 4., 1., 2., 5., 8., 1., 6., 0., 6., 4.,
  82.         1., 4., 6., 8., 6., 2., 2., 2., 2., 8., 2., 2., 5., 4., 4., 1., 2., 5.,
  83.         1., 6., 5., 3., 2., 0., 1., 1., 5., 6., 2., 2., 6., 2., 8., 4., 1., 6.,
  84.         4., 2., 3., 8., 1., 1., 2., 2., 6., 2., 3., 6., 2., 4., 7., 4., 6., 8.,
  85.         2., 4., 3., 2., 2., 7., 2., 1., 2., 2., 6.])
  86. tensor([ 2, 13, 26,  4, 10,  5, 15,  2,  6])
  87. tensor([0.5000, 0.0769, 0.0385, 0.2500, 0.1000, 0.2000, 0.0667, 0.5000, 0.1667])
  88.  
  89. ---------------------------------------------------------------------------
  90. RuntimeError                              Traceback (most recent call last)
  91. <ipython-input-254-06eff51f165c> in <module>()
  92.      92 weight = 1. / class_sample_count.float()
  93.      93 print(weight)
  94. ---> 94 samples_weight = torch.tensor([weight[t] for t in flat_targets])
  95.      95 sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
  96.      96
  97.  
  98. <ipython-input-254-06eff51f165c> in <listcomp>(.0)
  99.      92 weight = 1. / class_sample_count.float()
  100.      93 print(weight)
  101. ---> 94 samples_weight = torch.tensor([weight[t] for t in flat_targets])
  102.      95 sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
  103.      96
  104.  
  105. RuntimeError: tensors used as indices must be long or byte tensors
Add Comment
Please, Sign In to add comment