import numpy as np
import pickle


class ImageDataset(object):

    def __init__(self, datasets, transform=None, type='train', resize=False):
        super(ImageDataset, self).__init__()

        # Load the pickled dict of 'features' (images) and 'labels'.
        with open(datasets, mode='rb') as f:
            datasets = pickle.load(f)

        self.images = datasets['features']
        self.labels = datasets['labels']
        # Avoid a mutable default argument: fall back to an empty pipeline.
        self.transform = transform if transform is not None else []
        self.type = type
        self.resize = resize

    def get_train_item(self, index):
        image = self.images[index]
        label = self.labels[index]
        for t in self.transform:
            image = t(image)

        return image, label, index

    def get_test_item(self, index):
        image = self.images[index]
        for t in self.transform:
            image = t(image)

        return image, index

    def __getitem__(self, index):
        if self.type == 'train': return self.get_train_item(index)
        if self.type == 'test': return self.get_test_item(index)

    def __len__(self):
        return len(self.images)
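

# Example transforms (an illustrative sketch, not part of the original paste):
# each entry in ``transform`` is a plain callable that takes a numpy image and
# returns one, and ImageDataset applies the list in order per item.
def normalize(image):
    """Scale uint8 pixel values from [0, 255] to float32 [0, 1]."""
    return image.astype(np.float32) / 255.0


def grayscale(image):
    """Collapse an HxWxC image to HxWx1 by averaging over the channel axis."""
    return image.mean(axis=2, keepdims=True)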


class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    an iterator over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        drop_last (bool, optional): set to ``True`` to drop the last incomplete
            batch if the dataset size is not divisible by the batch size. If
            ``False`` and the size of the dataset is not divisible by the batch
            size, then the last batch will be smaller. (default: False)

    A ``RandomSampler`` and a ``BatchSampler`` are constructed internally, so
    every epoch visits the dataset in a fresh random order.
    """

    def __init__(self, dataset, batch_size=1, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size

        # Placeholder; collation happens in DataLoaderIter.__next__.
        self.collate_fn = []

        sampler = RandomSampler(dataset)
        batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
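

# Usage sketch (hypothetical file path; assumes the pickle holds 'features'
# and 'labels' as ImageDataset expects, mirroring the __main__ block below):
#
#   dataset = ImageDataset("data/train.p", transform=[normalize])
#   loader = DataLoader(dataset, batch_size=32, drop_last=True)
#   for images, labels, indices in loader:
#       ...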


class BatchSampler(object):
    """Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.

    Example:
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
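

# Length check sketch (illustrative, not in the original paste): with 10
# samples and batch_size=3, drop_last=False gives ceil(10 / 3) == 4 batches
# via the (n + b - 1) // b trick above, while drop_last=True gives 10 // 3 == 3:
#
#   >>> len(BatchSampler(range(10), 3, drop_last=False))
#   4
#   >>> len(BatchSampler(range(10), 3, drop_last=True))
#   3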


class RandomSampler(object):
    """Samples elements randomly, without replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(np.random.permutation(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
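

# Composition sketch (illustrative; actual order varies per run): wrapping
# RandomSampler in BatchSampler yields shuffled index batches, e.g.
#
#   >>> list(BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=True))
#   [[7, 2, 9], [0, 4, 1], [8, 3, 6]]
#
# which is exactly the pairing DataLoader builds in its __init__.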


class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.sample_iter = iter(self.batch_sampler)

    def __len__(self):
        return len(self.batch_sampler)

    def __next__(self):
        indices = next(self.sample_iter)  # may raise StopIteration
        batch = [self.dataset[i] for i in indices]
        # Regroup the list of per-sample tuples into per-field tuples, e.g.
        # [(image, label, index), ...] -> (images, labels, indices). zip
        # handles image arrays of any shape, unlike np.transpose on a
        # ragged list of tuples.
        batch = list(zip(*batch))
        return batch

    def __iter__(self):
        return self

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue,
        # but signalling the end is tricky without a non-blocking API.
        raise NotImplementedError("DataLoaderIter cannot be pickled")


if __name__ == "__main__":

    dataset = ImageDataset(datasets="data/train.p")
    train_loader = DataLoader(
        dataset,
        batch_size=4,
    )

    for i, (images, labels, indices) in enumerate(train_loader):
        print('i=%d' % i)
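
    # Added sketch (not in the original paste): the same dataset file can be
    # read in test mode, where __getitem__ returns (image, index) pairs.
    test_set = ImageDataset(datasets="data/train.p", type='test')
    test_loader = DataLoader(test_set, batch_size=4)
    for images, indices in test_loader:
        print('test batch indices:', indices)
        break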