weighted random sampler
lamiastella
Nov 29th, 2018
######################################################################
# Load Data
# ---------
#
# We will use torchvision and torch.utils.data packages for loading the
# data.
#
# The problem we're going to solve today is to train a model to classify
# **ants** and **bees**. We have about 120 training images each for ants and bees.
# There are 75 validation images for each class. Usually, this is a very
# small dataset to generalize upon, if trained from scratch. Since we
# are using transfer learning, we should be able to generalize reasonably
# well.
#
# This dataset is a very small subset of imagenet.
#
# .. Note ::
#    Download the data from
#    `here <https://download.pytorch.org/tutorial/hymenoptera_data.zip>`_
#    and extract it to the current directory.

import os

import torch
from torch.utils.data import WeightedRandomSampler
from torchvision import datasets, transforms

# Data augmentation and normalization for training
# Just normalization for validation


# Using the weighted sampling method from
# https://github.com/ptrblck/pytorch_misc/blob/master/weighted_sampling.py#L25


data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(0.3, 0.3, 0.3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
#    'val': transforms.Compose([
#        transforms.Resize(256),
#        transforms.CenterCrop(224),
#        transforms.ToTensor(),
#        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}


data_dir = "10folds/10fold_9"

# ImageFolder subclass that returns the image path alongside the sample and
# its target label.
class MonaDataset(datasets.folder.ImageFolder):
    def __init__(self, root, transform=None, target_transform=None,
                 loader=datasets.folder.default_loader):
        super(MonaDataset, self).__init__(root, transform, target_transform, loader)

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path

image_datasets = {x: MonaDataset(os.path.join(data_dir, x),
                                 data_transforms[x])
                  for x in ['train', 'test']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'test']}


# Collect the class labels by iterating over the train and test loaders.
targets = []
with torch.no_grad():
    for i, (inputs, classes, im_path) in enumerate(dataloaders['train']):
        targets.append(classes.cpu().detach().numpy().tolist())

with torch.no_grad():
    for i, (inputs, classes, im_path) in enumerate(dataloaders['test']):
        targets.append(classes.cpu().detach().numpy().tolist())

# Flatten the per-batch label lists and compute inverse-frequency class weights.
flat_targets = [item for sublist in targets for item in sublist]
flat_targets = torch.Tensor(flat_targets)
print(flat_targets)
class_sample_count = torch.tensor([(flat_targets == t).sum() for t in torch.unique(flat_targets, sorted=True)])
print(class_sample_count)
weight = 1. / class_sample_count.float()
print(weight)
samples_weight = torch.tensor([weight[t] for t in flat_targets])
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))


dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4, sampler=sampler)
               for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}


class_names = image_datasets['train'].classes
-------------------------------------------------------------

tensor([ 21, 136, 260,  44, 103,  57, 152,  28,  63])
tensor([0.0476, 0.0074, 0.0038, 0.0227, 0.0097, 0.0175, 0.0066, 0.0357, 0.0159])

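These are the printed class_sample_count and weight tensors from the run above: nine classes in this fold, and each weight is simply the reciprocal of the corresponding count (1/21 ≈ 0.0476, 1/136 ≈ 0.0074, 1/260 ≈ 0.0038, and so on), so the rarest class receives the largest sampling weight.
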
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-257-8e92e1fd5cb3> in <module>()
     95 weight = 1. / class_sample_count.float()
     96 print(weight)
---> 97 samples_weight = torch.tensor([weight[t] for t in flat_targets])
     98 sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
     99

<ipython-input-257-8e92e1fd5cb3> in <listcomp>(.0)
     95 weight = 1. / class_sample_count.float()
     96 print(weight)
---> 97 samples_weight = torch.tensor([weight[t] for t in flat_targets])
     98 sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
     99

RuntimeError: tensors used as indices must be long or byte tensors
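
The RuntimeError is raised because torch.Tensor(flat_targets) builds a float tensor, and a float tensor cannot be used to index the weight tensor; PyTorch only accepts long (or byte) tensors as indices. Below is a minimal sketch of one possible fix, not the original code: it keeps the same inverse-frequency weighting, but reads the training labels straight from the ImageFolder samples list (so samples_weight[i] stays aligned with dataset index i rather than with the shuffled loader order), casts them to long, and attaches the sampler only to the train loader, since DataLoader rejects a sampler combined with shuffle=True.

# Minimal sketch of a possible fix, assuming the rest of the script above is
# unchanged. The labels are read directly from the ImageFolder so that the
# per-sample weights line up with dataset indices, and cast to long so they
# can index the per-class weight tensor.
train_targets = torch.tensor([label for _, label in image_datasets['train'].samples],
                             dtype=torch.long)

class_sample_count = torch.tensor(
    [(train_targets == t).sum() for t in torch.unique(train_targets, sorted=True)])
weight = 1. / class_sample_count.float()
samples_weight = torch.tensor([weight[t] for t in train_targets])

sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

# DataLoader does not allow sampler together with shuffle=True, so the sampler
# replaces shuffling on the train loader; the test loader is left unshuffled.
dataloaders = {
    'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=4,
                                         sampler=sampler, num_workers=4),
    'test': torch.utils.data.DataLoader(image_datasets['test'], batch_size=4,
                                        shuffle=False, num_workers=4),
}

With this in place, each training batch is drawn with replacement in inverse proportion to class frequency, which is the point of the weighted sampling recipe linked at the top of the script.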