Guest User

Untitled

a guest
Dec 16th, 2017
215
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.35 KB | None | 0 0
import multiprocessing.pool
import os
from functools import partial

import h5py
import numpy as np
from keras import backend as K
from keras.preprocessing.image import (
    ImageDataGenerator,
    Iterator,
    _count_valid_files_in_directory,
    _list_valid_filenames_in_directory,
)
  9.  
  10.  
  11. class MatFileIterGenerator(object):
  12. def __init__(self):
  13. self.image_data_generator = ImageDataGenerator()
  14.  
  15. def flow_from_directory(self, directory, variable,
  16. target_size=(256, 256), classes=None, class_mode='categorical',
  17. batch_size=32, shuffle=True, seed=None,
  18. follow_links=False,
  19. interpolation='nearest'):
  20. return MatFilesIterator(
  21. directory, variable, self.image_data_generator,
  22. classes=classes, class_mode=class_mode,
  23. batch_size=batch_size, shuffle=shuffle, seed=seed,
  24. follow_links=follow_links,
  25. interpolation=interpolation)
  26.  
  27.  
class MatFilesIterator(Iterator):
    """Iterator over ``.mat`` (HDF5) files found in class subfolders of a directory.

    Modeled on Keras' ``DirectoryIterator``: each subdirectory of
    ``directory`` is one class, every ``.mat`` file inside is one sample,
    and each batch is built by reading the HDF5 dataset named ``variable``
    from the files via ``h5py``.
    """

    def __init__(self, directory, variable, image_data_generator, classes=None, class_mode="categorical",
                 batch_size=32, shuffle=True, seed=None, interpolation='nearest', follow_links=False):
        # Name of the HDF5 dataset to read out of every .mat file.
        self.variable = variable
        self.directory = directory
        self.image_data_generator = image_data_generator
        self.data_format = K.image_data_format()
        self.classes = classes

        if class_mode not in {'categorical', 'binary', 'sparse',
                              'input', None}:
            raise ValueError('Invalid class_mode:', class_mode,
                             '; expected one of "categorical", '
                             '"binary", "sparse", "input"'
                             ' or None.')
        self.class_mode = class_mode
        # NOTE(review): stored but never used elsewhere in this class --
        # presumably kept for signature parity with DirectoryIterator.
        self.interpolation = interpolation

        # Only files with this extension are indexed.
        white_list_formats = {"mat"}

        # first, count the number of samples and classes
        self.samples = 0

        if not classes:
            # Infer class names from the sorted subdirectory names.
            classes = []
            for subdir in sorted(os.listdir(directory)):
                if os.path.isdir(os.path.join(directory, subdir)):
                    classes.append(subdir)
        self.num_classes = len(classes)
        # Map class name -> integer label (in sorted-directory order).
        self.class_indices = dict(zip(classes, range(len(classes))))

        # Count valid files across all class folders in parallel, using
        # Keras' private helper.
        pool = multiprocessing.pool.ThreadPool()
        function_partial = partial(_count_valid_files_in_directory,
                                   white_list_formats=white_list_formats,
                                   follow_links=follow_links)
        self.samples = sum(pool.map(function_partial,
                                    (os.path.join(directory, subdir)
                                     for subdir in classes)))

        print('Found %d files belonging to %d classes.' % (self.samples, self.num_classes))

        # second, build an index of the images in the different class subfolders
        results = []

        self.filenames = []
        # self.classes is rebound here from the constructor argument to the
        # per-sample integer-label array.
        self.classes = np.zeros((self.samples,), dtype='int32')
        i = 0
        for dirpath in (os.path.join(directory, subdir) for subdir in classes):
            results.append(pool.apply_async(_list_valid_filenames_in_directory,
                                            (dirpath, white_list_formats,
                                             self.class_indices, follow_links)))
        for res in results:
            # NOTE: this rebinds the local ``classes`` list to the label
            # array returned per folder; safe because the class-name list
            # is no longer needed past this point.
            classes, filenames = res.get()
            self.classes[i:i + len(classes)] = classes
            self.filenames += filenames
            i += len(classes)
        pool.close()
        pool.join()

        super(MatFilesIterator, self).__init__(self.samples, batch_size, shuffle, seed)

    def _get_batches_of_transformed_samples(self, index_array):
        """Read, transform and stack the samples selected by ``index_array``.

        Returns ``batch_x`` alone when ``class_mode`` is None, otherwise
        ``(batch_x, batch_y)``.
        """
        # The script fails here with mentioned error
        # 60 is the mentioned constant row numbers
        # NOTE(review): batch_x is rank 2, shape (batch, 60); the
        # ``batch_x[i] = arr`` assignment below therefore assumes each
        # file's array yields exactly 60 values -- TODO confirm against the
        # actual .mat contents and the (60, None, 1) model input shape.
        batch_x = np.zeros(tuple([len(index_array)] + [60]), dtype=K.floatx())

        # build batch of numpy data
        for i, j in enumerate(index_array):
            fname = self.filenames[j]
            # NOTE(review): the h5py.File handle is never explicitly closed;
            # this relies on garbage collection to release it.
            arr = np.array(h5py.File(os.path.join(self.directory, fname), "r").get(self.variable))

            arr = self.image_data_generator.random_transform(arr.astype(K.floatx()))
            arr = self.image_data_generator.standardize(arr)
            batch_x[i] = arr

        # build batch of labels
        if self.class_mode == 'input':
            batch_y = batch_x.copy()
        elif self.class_mode == 'sparse':
            batch_y = self.classes[index_array]
        elif self.class_mode == 'binary':
            batch_y = self.classes[index_array].astype(K.floatx())
        elif self.class_mode == 'categorical':
            # One-hot encode the integer labels.
            batch_y = np.zeros((len(batch_x), self.num_classes), dtype=K.floatx())
            for i, label in enumerate(self.classes[index_array]):
                batch_y[i, label] = 1.
        else:
            # class_mode is None: yield inputs only.
            return batch_x
        return batch_x, batch_y

    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)
  132.  
  133. from keras.layers import Activation, Conv2D, MaxPooling2D, GlobalMaxPooling2D, Dense, Dropout
  134. from keras.models import Sequential
  135.  
  136. from matfileiter import MatFileIterGenerator
  137.  
  138.  
  139. class CNN2D:
  140.  
  141. def __init__(self):
  142. self._model = Sequential()
  143.  
  144. self._model.add(Conv2D(60, (3, 3), input_shape=(60, None , 1)))
  145. self._model.add(Activation("relu"))
  146. self._model.add(MaxPooling2D(pool_size=(3, 3)))
  147.  
  148. self._model.add(Conv2D(60, (3, 3)))
  149. self._model.add(Activation("relu"))
  150. self._model.add(MaxPooling2D(pool_size=(3, 3)))
  151.  
  152. self._model.add(Conv2D(120, (3, 3)))
  153. self._model.add(Activation("relu"))
  154. self._model.add(MaxPooling2D(pool_size=(3, 3)))
  155.  
  156. self._model.add(GlobalMaxPooling2D())
  157. self._model.add(Dense(120))
  158. self._model.add(Activation('relu'))
  159. self._model.add(Dropout(0.2))
  160. self._model.add(Dense(1))
  161. self._model.add(Activation('sigmoid'))
  162.  
  163. self._model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
  164.  
  165. def createGenerators(self, train_path, variable, test_path, batch_size):
  166. train_datagen = MatFileIterGenerator()
  167. self._train_generator = train_datagen.flow_from_directory(
  168. train_path,
  169. variable,
  170. shuffle=True,
  171. batch_size=batch_size,
  172. class_mode="binary")
  173.  
  174. test_datagen = MatFileIterGenerator()
  175. self._test_generator = test_datagen.flow_from_directory(
  176. test_path,
  177. variable,
  178. shuffle=True,
  179. batch_size=batch_size,
  180. class_mode="binary")
  181.  
  182. def train_model(self, batch_size):
  183. self._model.fit_generator(
  184. self._train_generator,
  185. steps_per_epoch=2000 // batch_size,
  186. epochs=50,
  187. validation_data=self._test_generator,
  188. validation_steps=800 // batch_size,
  189. workers=2,
  190. use_multiprocessing=True)
  191.  
  192.  
  193. if __name__ == '__main__':
  194. cnn = CNN2D()
  195. cnn.createGenerators("/home/wilson/Documents/Data/_train_mat", "coeffs",
  196. "/home/wilosn/Documents/Data/_eval_mat", 20)
  197. cnn.train_model(20)
Add Comment
Please, Sign In to add comment