Advertisement
r4m0n

BcolzArrayIterator

Jan 20th, 2017
527
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.26 KB | None | 0 0
  1. class BcolzArrayIterator(object):
  2.     def __init__(self, X, y=None, w=None, batch_size=32, shuffle=False, seed=None):
  3.         if y is not None and len(X) != len(y):
  4.             raise ValueError('X (features) and y (labels) '
  5.                              'should have the same length. '
  6.                              'Found: X.shape = %s, y.shape = %s' % (X.shape, y.shape))
  7.         if w is not None and len(X) != len(w):
  8.             raise ValueError('X (features) and w (weights) '
  9.                              'should have the same length. '
  10.                              'Found: X.shape = %s, w.shape = %s' % (X.shape, w.shape))
  11.         if batch_size % X.chunklen != 0:
  12.             raise ValueError('batch_size needs to be a multiple of X.chunklen')
  13.         self.chunks_per_batch = batch_size // X.chunklen
  14.         self.X = X
  15.         if y is not None:
  16.             self.y = y[:]
  17.         else:
  18.             self.y = None
  19.         if w is not None:
  20.             self.w = w[:]
  21.         else:
  22.             self.w = None
  23.         self.N = X.shape[0]
  24.         self.batch_size = batch_size
  25.         self.batch_index = 0
  26.         self.total_batches_seen = 0
  27.         self.lock = threading.Lock()
  28.         self.shuffle = shuffle
  29.         self.seed = seed
  30.        
  31.     def reset(self):
  32.         self.batch_index = 0
  33.        
  34.     def next(self):
  35.         with self.lock:
  36.             if self.batch_index == 0:
  37.                 if self.seed is not None:
  38.                     np.random.seed(self.seed + self.total_batches_seen)
  39.                 if self.shuffle:
  40.                     self.index_array = np.random.permutation(self.X.nchunks + 1)
  41.                 else:
  42.                     self.index_array = np.arange(self.X.nchunks + 1)
  43.  
  44.             batches_x = []
  45.             batches_y = []
  46.             batches_w = []
  47.             for i in range(self.chunks_per_batch):
  48.                 current_index = self.index_array[self.batch_index]
  49.                 if current_index == self.X.nchunks:
  50.                     batches_x.append(self.X.leftover_array[:self.X.leftover_elements])
  51.                     current_batch_size = self.X.leftover_elements
  52.                 else:
  53.                     batches_x.append(self.X.chunks[current_index][:])
  54.                     current_batch_size = self.X.chunklen
  55.                 self.batch_index += 1
  56.                 self.total_batches_seen += 1
  57.                 if not self.y is None:
  58.                     batches_y.append(self.y[current_index * self.X.chunklen: current_index * self.X.chunklen + current_batch_size])
  59.                 if not self.w is None:
  60.                     batches_w.append(self.w[current_index * self.X.chunklen: current_index * self.X.chunklen + current_batch_size])
  61.                 if self.batch_index >= len(self.index_array):
  62.                     self.batch_index = 0
  63.                     break
  64.            
  65.         batch_x = np.concatenate(batches_x)
  66.         if self.y is None:
  67.             return batch_x
  68.        
  69.         batch_y = np.concatenate(batches_y)
  70.         if self.w is None:
  71.             return batch_x, batch_y
  72.        
  73.         batch_w = np.concatenate(batches_w)
  74.         return batch_x, batch_y, batch_w
  75.  
  76.     def __iter__(self):
  77.         return self
  78.  
  79.     def __next__(self, *args, **kwargs):
  80.         return self.next(*args, **kwargs)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement