Advertisement
Guest User

Untitled

a guest
Sep 21st, 2018
305
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.13 KB | None | 0 0
  1. import tensorflow as tf
  2. from keras.layers import *
  3. from keras.models import Model
  4. from keras.utils import plot_model
  5.  
  6. import numpy as np
  7.  
  8. import matplotlib.pyplot as plt
  9. import matplotlib.image as mpimg
  10. import skimage
  11. from skimage import transform
  12.  
  13.  
  14. data_path = 'pictures/'
  15. img_width = 1400
  16. img_height = 1400
  17. img_depth = 3
  18.  
  19.  
  20. # see: https://keras.io/preprocessing/image/
  21. def dataAugmentation(dataToAugment):
  22. arrayToFill = []
  23.  
  24. # faster computation with values between 0 and 1 ?
  25. dataToAugment = np.divide(dataToAugment, 255.)
  26.  
  27. # TODO: switch from RGB channels to CbCrY
  28.  
  29. # adding the normal images (8)
  30. for i in range(len(dataToAugment)):
  31. arrayToFill.append(dataToAugment[i])
  32. # vertical axis flip (-> 16)
  33. for i in range(len(arrayToFill)):
  34. arrayToFill.append(np.fliplr(arrayToFill[i]))
  35. # horizontal axis flip (-> 32)
  36. for i in range(len(arrayToFill)):
  37. arrayToFill.append(np.flipud(arrayToFill[i]))
  38.  
  39. # downsizing by scale of 4 (-> 64 images of 350x350x3)
  40. for i in range(len(arrayToFill)):
  41. arrayToFill.append(skimage.transform.resize(
  42. arrayToFill[i],
  43. (img_width/4, img_height/4),
  44. mode='reflect',
  45. anti_aliasing=True))
  46.  
  47. # # Sanity check: display the images
  48. # plt.figure(figsize=(10, 10))
  49. # for i in range(64):
  50. # plt.subplot(8, 8, i + 1)
  51. # plt.imshow(arrayToFill[i], cmap=plt.cm.binary)
  52. # plt.show()
  53.  
  54. return arrayToFill
  55.  
  56.  
  57. def setUpImages():
  58.  
  59. # Setting up paths
  60. path1 = data_path + 'Moi.jpg'
  61. path2 = data_path + 'ASaucerfulOfSecrets.jpg'
  62. path3 = data_path + 'AtomHeartMother.jpg'
  63. path4 = data_path + 'Animals.jpg'
  64. path5 = data_path + 'DivisionBell.jpg' # validator
  65. path6 = data_path + 'lighter.jpg'
  66. path7 = data_path + 'Meddle.jpg' # validator
  67. path8 = data_path + 'ObscuredByClouds.jpg' # validator
  68. path9 = data_path + 'TheDarkSideOfTheMoon.jpg'
  69. path10 = data_path + 'TheWall.jpg'
  70. path11 = data_path + 'WishYouWereHere.jpg'
  71.  
  72. # Extracting images (1400x1400)
  73. train = [mpimg.imread(path1),
  74. mpimg.imread(path2),
  75. mpimg.imread(path3),
  76. mpimg.imread(path4),
  77. mpimg.imread(path6),
  78. mpimg.imread(path9),
  79. mpimg.imread(path10),
  80. mpimg.imread(path11)]
  81. finalTest = [mpimg.imread(path5),
  82. mpimg.imread(path8),
  83. mpimg.imread(path7)]
  84.  
  85. # Augmenting data
  86. trainData = dataAugmentation(train)
  87. testData = dataAugmentation(finalTest)
  88.  
  89. setUpData(trainData, testData)
  90.  
  91.  
  92. def setUpData(trainData, testData):
  93.  
  94. print(type(trainData)) # <class 'list'>
  95. print(len(trainData)) # 64
  96. print(type(trainData[0])) # <class 'numpy.ndarray'>
  97. print(trainData[0].shape) # (1400, 1400, 3)
  98. # conversion to Numpy Arrays of shape (num_img, width, height, channels=3)
  99. trainData = np.array(trainData)
  100. testData = np.array(testData)
  101. print(type(trainData)) # <class 'numpy.ndarray'>
  102. print(trainData.shape) # (64,)
  103.  
  104. # TODO: substract mean of all images to all images
  105.  
  106. # # 16 images with 3 flattened channels
  107. # trainData = trainData.flatten().reshape(trainSize, img_width * img_height, 3)
  108. # testData = testData.flatten().reshape(testSize, img_width * img_height, 3)
  109. #
  110. # # Sanity check: displaying image
  111. # plt.imshow(trainData[2].reshape(img_width, img_height, img_depth))
  112. # plt.show()
  113. #
  114. # # Sanity check: displaying flipped image
  115. # plt.imshow(trainData[2+8].reshape(img_width, img_height, img_depth))
  116. # plt.show()
  117.  
  118. # Separating the training data
  119. halved = np.split(trainData, 2)
  120. validateData = halved[0] # First half is the unaltered data
  121. trainingData = halved[1] # Second half is the deteriorated data
  122.  
  123. # Separating the testing data
  124. halved = np.split(testData, 2)
  125. validateTestData = halved[0] # First half is the unaltered data
  126. trainingTestData = halved[1] # Second half is the deteriorated data
  127.  
  128. print(type(validateData)) # <class 'numpy.ndarray'>
  129. print(validateData.shape) # (32,)
  130. print(type(validateData[0].shape)) # <class 'tuple'>
  131. print(validateData[0].shape) # (1400, 1400, 3)
  132.  
  133. # # Sanity check: display four images (2x HR/LR)
  134. # plt.figure(figsize=(10, 10))
  135. # for i in range(2):
  136. # plt.subplot(2, 2, i + 1)
  137. # plt.imshow(validateData[i], cmap=plt.cm.binary)
  138. # for i in range(2):
  139. # plt.subplot(2, 2, i + 1 + 2)
  140. # plt.imshow(trainingData[i], cmap=plt.cm.binary)
  141. # plt.show()
  142.  
  143. setUpModel(validateData, trainingData, validateTestData, trainingTestData)
  144.  
  145.  
  146. def setUpModel(validateData, trainingData, validateTestData, trainingTestData):
  147.  
  148. # # exemple de merge de deux networks: merge = concatenate([network1, network2])
  149. # # exemple de deux inputs pour un seul model: model = Model(inputs=[visible1, visible2], outputs=output)
  150.  
  151. filters = 256
  152. kernel_size = 3
  153. strides = 1
  154. factor = 4 # the factor of upscaling
  155.  
  156. inputLayer = Input(shape=(img_height/factor, img_width/factor, img_depth))
  157. conv1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(inputLayer)
  158.  
  159. res = Conv2D(filters, kernel_size, strides=strides, padding='same')(conv1)
  160. act = ReLU()(res)
  161. res = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
  162. res_rec = Add()([conv1, res])
  163.  
  164. for i in range(15): # 16-1
  165. res1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
  166. act = ReLU()(res1)
  167. res2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
  168. res_rec = Add()([res_rec, res2])
  169.  
  170. conv2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
  171. a = Add()([conv1, conv2])
  172. up = UpSampling2D(size=4)(a)
  173. outputLayer = Conv2D(filters=3,
  174. kernel_size=1,
  175. strides=1,
  176. padding='same')(up)
  177.  
  178. model = Model(inputs=inputLayer, outputs=outputLayer)
  179.  
  180. # Sanity checks
  181. print(model.summary())
  182. plot_model(model, to_file='CNN_graph.png')
  183.  
  184. train(model, validateData, trainingData, validateTestData, trainingTestData)
  185.  
  186.  
  187. def train(model, validateData, trainingData, validateTestData, trainingTestData):
  188. model.compile(optimizer=tf.train.AdamOptimizer(),
  189. loss='sparse_categorical_crossentropy', # TODO: Customize loss function?
  190. metrics=['accuracy'])
  191. # TODO: possibly need to be LISTS instead of np.array ?
  192. model.fit(trainingData,
  193. validateData,
  194. epochs=5,
  195. verbose=2,
  196. batch_size=4) # 32 images-> 8 batches of 4 TODO is it multi-fold testing?
  197.  
  198. # Now use the TEST dataset to calculate performance
  199. test_loss, test_acc = model.evaluate(trainingTestData, validateTestData)
  200. print('Test accuracy:', test_acc)
  201.  
  202.  
  203.  
  204.  
  205. ###########################
  206. # PREDICTIONS #
  207. ###########################
  208.  
  209. # Trying to make predictions on a single image
  210. predictions = model.predict(trainingTestData)
  211.  
  212. # "model.predict" works in batches, so extracting a single prediction:
  213. img = trainingTestData[0] # Grab an image from the test dataset
  214. img = (np.expand_dims(img, 0)) # Add the image to a batch where it's the only member.
  215. predictions_single = model.predict(img) # returns a list of lists, one for each image in the batch of data
  216. print(predictions_single)
  217.  
  218. ###########################
  219. # DRAWINGS #
  220. ###########################
  221.  
  222. # def plot_image(i, predictions_array, true_label, img):
  223. # predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
  224. # plt.grid(False)
  225. # plt.xticks([])
  226. # plt.yticks([])
  227. #
  228. # plt.imshow(img, cmap=plt.cm.binary)
  229. #
  230. # predicted_label = np.argmax(predictions_array)
  231. # if predicted_label == true_label:
  232. # color = 'blue'
  233. # else:
  234. # color = 'red'
  235. #
  236. # plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
  237. # 100 * np.max(predictions_array),
  238. # class_names[true_label]),
  239. # color=color)
  240. #
  241. # def plot_value_array(i, predictions_array, true_label):
  242. # predictions_array, true_label = predictions_array[i], true_label[i]
  243. # plt.grid(False)
  244. # plt.xticks([])
  245. # plt.yticks([])
  246. # thisplot = plt.bar(range(10), predictions_array, color="#777777")
  247. # plt.ylim([0, 1])
  248. # predicted_label = np.argmax(predictions_array)
  249. # plt.xticks(range(10)) # adding the class-index below prediction graph
  250. #
  251. # thisplot[predicted_label].set_color('red')
  252. # thisplot[true_label].set_color('blue')
  253. #
  254. # # def draw_prediction(index):
  255. # # plt.figure(figsize=(6, 3))
  256. # # plt.subplot(1, 2, 1)
  257. # # plot_image(index, predictions, test_labels, test_images)
  258. # # plt.subplot(1, 2, 2)
  259. # # plot_value_array(index, predictions, test_labels)
  260. # # plt.show()
  261. # #
  262. # # To draw a single prediction
  263. # # draw_prediction(0)
  264. # # draw_prediction(12)
  265. #
  266. # # Plot the first X test images, their predicted label, and the true label
  267. # # Color correct predictions in blue, incorrect predictions in red
  268. # num_rows = 5
  269. # num_cols = 3
  270. # num_images = num_rows * num_cols
  271. # plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
  272.  
  273. # Adding a title to the plot
  274. plt.suptitle("Check it out!")
  275.  
  276. # for i in range(num_images):
  277. # plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
  278. # plot_image(i, predictions, test_labels, test_images)
  279. # plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
  280. # plot_value_array(i, predictions, test_labels)
  281. # plt.show()
  282.  
  283.  
  284. if __name__ == '__main__':
  285. setUpImages()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement