SHARE
TWEET

Untitled

Posted by a guest on Oct 23rd, 2019 · 66 views · Expires: Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy as np
  2. import cv2
  3. from tensorflow.keras.utils import Sequence
  4.  
  5.  
  class DataGenerator(Sequence):
      """Keras ``Sequence`` yielding batches of grayscale images (and masks).

      Suitable for training (``to_fit=True`` yields ``(X, y)``) and for
      prediction (``to_fit=False`` yields ``X`` only).
      """

      def __init__(self, list_IDs, labels, image_path, mask_path,
                   to_fit=True, batch_size=32, dim=(256, 256),
                   n_channels=1, n_classes=10, shuffle=True):
          """Initialization

          :param list_IDs: ids of all samples to use in the generator; each id
              is used as a key/index into ``labels``
          :param labels: indexable by the ids in ``list_IDs``; each entry is a
              file name shared by an image and its mask
          :param image_path: prefix prepended to each file name to locate the
              image; it is concatenated verbatim, so include the trailing
              path separator
          :param mask_path: prefix prepended to each file name to locate the
              mask (same convention as ``image_path``)
          :param to_fit: True to return X and y, False to return X only
          :param batch_size: number of samples per batch
          :param dim: ``(height, width)`` that every image and mask must have
          :param n_channels: number of channels in X; the single grayscale
              plane is replicated into each of them
          :param n_classes: stored on the instance but not used by the
              methods of this class
          :param shuffle: True to shuffle the sample order after every epoch
          """
          self.list_IDs = list_IDs
          self.labels = labels
          self.image_path = image_path
          self.mask_path = mask_path
          self.to_fit = to_fit
          self.batch_size = batch_size
          self.dim = dim
          self.n_channels = n_channels
          self.n_classes = n_classes
          self.shuffle = shuffle
          # Build the initial (optionally shuffled) sample order.
          self.on_epoch_end()

      def __len__(self):
          """Denote the number of batches per epoch.

          Uses floor division, so a trailing partial batch (fewer than
          ``batch_size`` samples) is never produced; those samples are not
          served in that epoch.

          :return: number of batches per epoch
          """
          return len(self.list_IDs) // self.batch_size

      def __getitem__(self, index):
          """Generate one batch of data.

          :param index: index of the batch
          :return: ``(X, y)`` when ``to_fit`` is true, otherwise ``X`` only
          """
          # Positions (into list_IDs) of the samples belonging to this batch.
          start = index * self.batch_size
          batch_indexes = self.indexes[start:start + self.batch_size]
          list_IDs_temp = [self.list_IDs[k] for k in batch_indexes]

          X = self._generate_X(list_IDs_temp)
          if not self.to_fit:
              return X
          y = self._generate_y(list_IDs_temp)
          return X, y

      def on_epoch_end(self):
          """Reset the sample order after each epoch, shuffling it if requested.

          Also called once from ``__init__`` to create ``self.indexes``.
          """
          self.indexes = np.arange(len(self.list_IDs))
          if self.shuffle:
              np.random.shuffle(self.indexes)

      def _generate_X(self, list_IDs_temp):
          """Load the images for one batch.

          :param list_IDs_temp: ids of the samples to load
          :return: float array of shape
              ``(len(list_IDs_temp), *dim, n_channels)``, pixel values scaled
              by 1/255
          """
          # Size by the ids actually requested so np.empty never hands back
          # uninitialised rows (e.g. for an out-of-range batch index).
          X = np.empty((len(list_IDs_temp), *self.dim, self.n_channels))

          for i, ID in enumerate(list_IDs_temp):
              img = self._load_grayscale_image(self.image_path + self.labels[ID])
              # The loader returns (height, width). Add a trailing channel
              # axis so it broadcasts into X[i], which is
              # (height, width, n_channels); without it the assignment
              # raises a broadcasting ValueError.
              X[i] = img[..., np.newaxis]

          return X

      def _generate_y(self, list_IDs_temp):
          """Load the masks for one batch.

          :param list_IDs_temp: ids of the samples to load
          :return: int array of shape ``(len(list_IDs_temp), *dim)``
          """
          y = np.empty((len(list_IDs_temp), *self.dim), dtype=int)

          for i, ID in enumerate(list_IDs_temp):
              # The loader returns img / 255 as floats; storing into an int
              # array truncates, so only pixels equal to 255 become 1 and
              # every other value becomes 0.
              # NOTE(review): fine for 0/255 binary masks; masks that hold
              # class indices (cf. n_classes) or lossy-compressed masks would
              # be corrupted by this -- confirm the expected mask format.
              y[i] = self._load_grayscale_image(self.mask_path + self.labels[ID])

          return y

      def _load_grayscale_image(self, image_path):
          """Load an image file as a single grayscale plane scaled by 1/255.

          :param image_path: path of the image file to load
          :return: float array of shape ``(height, width)``
          :raises OSError: if OpenCV cannot read the file
          """
          img = cv2.imread(image_path)
          # cv2.imread signals failure by returning None instead of raising;
          # fail here with the offending path rather than letting cvtColor
          # raise an opaque error on a None input.
          if img is None:
              raise OSError("Could not read image: {}".format(image_path))
          img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
          return img / 255
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top