Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- ######################################################################
- # Load Data
- # ---------
- #
- # We will use torchvision and torch.utils.data packages for loading the
- # data.
- #
- # The problem we're going to solve today is to train a model to classify
- # **ants** and **bees**. We have about 120 training images each for ants and bees.
- # There are 75 validation images for each class. Usually, this is a very
- # small dataset to generalize upon, if trained from scratch. Since we
- # are using transfer learning, we should be able to generalize reasonably
- # well.
- #
- # This dataset is a very small subset of imagenet.
- #
- # .. Note ::
- # Download the data from
- # `here <https://download.pytorch.org/tutorial/hymenoptera_data.zip>`_
- # and extract it to the current directory.
- # Data augmentation and normalization for training
- # Just normalization for validation
- #using the weight sampling method from https://github.com/ptrblck/pytorch_misc/blob/master/weighted_sampling.py#L25
- data_transforms = {
- 'train': transforms.Compose([
- transforms.RandomResizedCrop(224),
- transforms.RandomHorizontalFlip(),
- transforms.RandomRotation(20),
- transforms.ColorJitter(0.3, 0.3, 0.3),
- transforms.ToTensor(),
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- ]),
- # 'val': transforms.Compose([
- # transforms.Resize(256),
- # transforms.CenterCrop(224),
- # transforms.ToTensor(),
- # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- # ]),
- 'test': transforms.Compose([
- transforms.Resize(256),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- ]),
- }
- data_dir = "10folds/10fold_9"
- class MonaDataset(datasets.folder.ImageFolder):
- def __init__(self, root, transform=None, target_transform=None,
- loader=datasets.folder.default_loader):
- super(MonaDataset, self).__init__(root, transform, target_transform, loader)
- def __getitem__(self, index):
- path, target = self.samples[index]
- sample = self.loader(path)
- if self.transform is not None:
- sample = self.transform(sample)
- if self.target_transform is not None:
- target = self.target_transform(target)
- return sample, target, path
- image_datasets = {x: MonaDataset(os.path.join(data_dir, x),
- data_transforms[x])
- for x in ['train', 'test']}
- dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
- shuffle=True, num_workers=4)
- for x in ['train', 'test']}
- targets = []
- with torch.no_grad():
- for i, (inputs, classes, im_path) in enumerate(dataloaders['train']):
- targets.append(classes.cpu().detach().numpy().tolist())
- with torch.no_grad():
- for i, (inputs, classes, im_path) in enumerate(dataloaders['test']):
- targets.append(classes.cpu().detach().numpy().tolist())
- flat_targets = [item for sublist in targets for item in sublist]
- flat_targets = torch.Tensor(flat_targets)
- print(flat_targets)
- class_sample_count = torch.tensor([(flat_targets == t).sum() for t in torch.unique(flat_targets, sorted=True)])
- print(class_sample_count)
- weight = 1. / class_sample_count.float()
- print(weight)
- samples_weight = torch.tensor([weight[t] for t in flat_targets])
- sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
- dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
- shuffle=True, num_workers=4, sampler = sampler)
- for x in ['train', 'test']}
- dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
- class_names = image_datasets['train'].classes
- -------------------------------------------------------------
- tensor([ 21, 136, 260, 44, 103, 57, 152, 28, 63])
- tensor([0.0476, 0.0074, 0.0038, 0.0227, 0.0097, 0.0175, 0.0066, 0.0357, 0.0159])
- ---------------------------------------------------------------------------
- RuntimeError Traceback (most recent call last)
- <ipython-input-257-8e92e1fd5cb3> in <module>()
- 95 weight = 1. / class_sample_count.float()
- 96 print(weight)
- ---> 97 samples_weight = torch.tensor([weight[t] for t in flat_targets])
- 98 sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
- 99
- <ipython-input-257-8e92e1fd5cb3> in <listcomp>(.0)
- 95 weight = 1. / class_sample_count.float()
- 96 print(weight)
- ---> 97 samples_weight = torch.tensor([weight[t] for t in flat_targets])
- 98 sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
- 99
- RuntimeError: tensors used as indices must be long or byte tensors
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement