Advertisement
Guest User

Untitled

a guest
Oct 23rd, 2019
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.60 KB | None | 0 0
  1. import numpy as np
  2. import cv2
  3. from tensorflow.keras.utils import Sequence
  4.  
  5.  
  class DataGenerator(Sequence):
      """Generates data for Keras

      Sequence based data generator. Suitable for building data generator for
      training (``to_fit=True``: each item is ``(X, y)``) and prediction
      (``to_fit=False``: each item is ``X`` only).

      Only full batches are produced: ``__len__`` floors the sample count, so
      samples that do not fill a final batch are skipped for that epoch.
      """

      def __init__(self, list_IDs, labels, image_path, mask_path,
                   to_fit=True, batch_size=32, dim=(256, 256),
                   n_channels=1, n_classes=10, shuffle=True):
          """Initialization

          :param list_IDs: list of all 'label' ids to use in the generator
          :param labels: list of image labels (file names), indexed as labels[ID]
          :param image_path: path prefix for images; it is concatenated directly
              with the file name, so include a trailing separator
          :param mask_path: path prefix for masks; same concatenation rule as
              image_path
          :param to_fit: True to return X and y, False to return X only
          :param batch_size: batch size at each iteration
          :param dim: (rows, cols) of every image and mask
          :param n_channels: number of image channels; the single grayscale
              plane is broadcast across all of them
          :param n_classes: number of output masks (stored; not read by any
              method of this class)
          :param shuffle: True to shuffle label indexes after every epoch
          """
          self.list_IDs = list_IDs
          self.labels = labels
          self.image_path = image_path
          self.mask_path = mask_path
          self.to_fit = to_fit
          self.batch_size = batch_size
          self.dim = dim
          self.n_channels = n_channels
          self.n_classes = n_classes
          self.shuffle = shuffle
          # Builds self.indexes for the first epoch.
          self.on_epoch_end()

      def __len__(self):
          """Denotes the number of batches per epoch

          :return: number of *full* batches per epoch (remainder dropped)
          """
          return len(self.list_IDs) // self.batch_size

      def __getitem__(self, index):
          """Generate one batch of data

          :param index: index of the batch, in ``range(len(self))``
          :return: X and y when fitting. X only when predicting
          :raises IndexError: if index is outside ``range(len(self))``
          """
          n_batches = len(self)
          # Without this guard an out-of-range index selects no samples and the
          # np.empty batch buffers would be returned uninitialized.
          if not 0 <= index < n_batches:
              raise IndexError(
                  f"batch index {index} out of range for {n_batches} batches")

          start = index * self.batch_size
          batch_indexes = self.indexes[start:start + self.batch_size]
          list_IDs_temp = [self.list_IDs[k] for k in batch_indexes]

          X = self._generate_X(list_IDs_temp)
          if not self.to_fit:
              return X
          return X, self._generate_y(list_IDs_temp)

      def on_epoch_end(self):
          """Updates indexes after each epoch

          Resets ``self.indexes`` to ``0..len(list_IDs)-1`` and, if ``shuffle``
          is set, shuffles it in place with NumPy's global RNG.
          """
          self.indexes = np.arange(len(self.list_IDs))
          if self.shuffle:
              np.random.shuffle(self.indexes)

      def _generate_X(self, list_IDs_temp):
          """Generates data containing batch_size images

          :param list_IDs_temp: list of label ids to load
          :return: float array of shape (batch_size, *dim, n_channels)
          """
          X = np.empty((self.batch_size, *self.dim, self.n_channels))

          for i, ID in enumerate(list_IDs_temp):
              img = self._load_sample(self.image_path, ID)
              # The loader returns (rows, cols); add a trailing channel axis so it
              # broadcasts into the (rows, cols, n_channels) slot. Assigning the
              # 2-D array directly would fail NumPy's broadcasting rules.
              X[i] = img[..., np.newaxis]

          return X

      def _generate_y(self, list_IDs_temp):
          """Generates data containing batch_size masks

          :param list_IDs_temp: list of label ids to load
          :return: int array of shape (batch_size, *dim), i.e. a batch of masks
          """
          y = np.empty((self.batch_size, *self.dim), dtype=int)

          for i, ID in enumerate(list_IDs_temp):
              # NOTE(review): masks go through the same /255 scaling as images and
              # are then truncated by the int dtype, so only pixels equal to 255
              # become 1 and every other value becomes 0. That suits 0/255 binary
              # masks but would erase multi-class label values - confirm intended.
              y[i] = self._load_sample(self.mask_path, ID)

          return y

      def _load_sample(self, directory, ID):
          """Load the file ``labels[ID]`` under ``directory`` and check its size.

          :param directory: path prefix (image_path or mask_path)
          :param ID: id used to look up the file name in ``labels``
          :return: 2-D grayscale array scaled by 1/255
          :raises ValueError: if the loaded array's shape differs from ``dim``
          """
          path = directory + self.labels[ID]
          img = self._load_grayscale_image(path)
          expected = tuple(self.dim)
          if img.shape != expected:
              raise ValueError(
                  f"{path!r} has shape {img.shape}, expected {expected}")
          return img

      def _load_grayscale_image(self, image_path):
          """Load grayscale image

          :param image_path: path to image to load
          :return: loaded image as a 2-D array divided by 255
          :raises OSError: if OpenCV cannot read the file
          """
          img = cv2.imread(image_path)
          if img is None:
              # cv2.imread reports a missing/unreadable file by returning None
              # rather than raising; fail here with the offending path.
              raise OSError(f"cv2.imread could not read image file: {image_path!r}")
          img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
          return img / 255
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement