Advertisement
Guest User

Untitled

a guest
Nov 21st, 2017
53
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.57 KB | None | 0 0
  1. from keras import losses
  2. from keras import optimizers
  3. from keras.models import Sequential
  4. from keras.layers import Dense, Dropout, Activation, Flatten
  5. from keras.layers import Convolution2D, MaxPooling2D, Dense
  6.  
  7. from keras.utils import np_utils
  8.  
  9. from sklearn.preprocessing import MinMaxScaler
  10.  
  11. from rect import Rect
  12. from PIL import Image, ImageDraw
  13. from tqdm import tqdm
  14.  
  15. import matplotlib.pyplot as plt
  16. import numpy as np
  17. import math, random
  18. import utils, os
  19.  
  20. batch_size = 32
  21. epochs = 100
  22.  
  23. grid_size = 13
  24. cell_size = 32
  25. rects_in_cell = 5
  26. class_num = 43
  27.  
  28. cell_out_size = class_num + 5
  29. cell_vec_size = cell_out_size * rects_in_cell
  30.  
  31. input_width = grid_size * cell_size
  32. input_height = grid_size * cell_size
  33.  
  34. def loss_function(y_true, y_pred):
  35. return 0
  36.  
  37. def load_dataset(data_path):
  38. x = []
  39. y = []
  40. image_info = {}
  41. file_order = []
  42.  
  43. for f in tqdm(utils.list_files(data_path, 'ppm')):
  44. img = Image.open(f)
  45.  
  46. p, name = os.path.split(f)
  47. image_info[name] = [len(x), img.size]
  48. img = img.resize((input_width, input_height), Image.NEAREST)
  49.  
  50. file_order.append(name)
  51. x.append(np.asarray(img))
  52.  
  53. rects_map = {}
  54. rects = utils.read_description(os.path.join(data_path, 'gt.txt'))
  55. for r in rects:
  56. name = r[0]
  57. x1 = int(r[1])
  58. y1 = int(r[2])
  59. x2 = int(r[3])
  60. y2 = int(r[4])
  61. cl = int(r[5])
  62.  
  63. assert x1 < x2
  64. assert y1 < y2
  65.  
  66. info = image_info[name]
  67. w, h = info[1]
  68. x_mult = float(input_width) / float(w)
  69. y_mult = float(input_height) / float(h)
  70.  
  71. x1 = round(x1 * x_mult)
  72. x2 = round(x2 * x_mult)
  73. y1 = round(y1 * y_mult)
  74. y2 = round(y2 * y_mult)
  75.  
  76. assert x1 < x2
  77. assert y1 < y2
  78.  
  79. v = [Rect(x1, y1, x2 - x1, y2 - y1), cl]
  80. if name in rects_map:
  81. rects_map[name].append(v)
  82. else:
  83. rects_map[name] = [ v ]
  84.  
  85. for f in file_order:
  86. file_data = []
  87. for cy in range(grid_size):
  88. row = []
  89. for cx in range(grid_size):
  90. cell = [0.0] * cell_vec_size
  91.  
  92. # get coordinates of cell rect
  93. cell_rect = Rect(cx * cell_size,
  94. cy * cell_size,
  95. cell_size,
  96. cell_size)
  97.  
  98. # find intersected rects
  99. save_rects = []
  100. if f in rects_map:
  101. for r in rects_map[f]:
  102. if cell_rect.intersection(r[0]) != None:
  103. save_rects.append(r)
  104.  
  105. # write rects to output array
  106. #print('Save: {}, {}'.format(len(save_rects), f))
  107. for i in range(len(save_rects)):
  108. r = save_rects[i][0]
  109. c = save_rects[i][1]
  110.  
  111. offset = i * cell_out_size
  112. cell[offset] = r.x()
  113. cell[offset + 1] = r.y()
  114. cell[offset + 2] = r.width()
  115. cell[offset + 3] = r.height()
  116. cell[offset + 4] = 1.0
  117. cell[offset + 5 + cl] = 1.0
  118.  
  119. row.append(cell)
  120.  
  121. file_data.append(row)
  122.  
  123. y.append(file_data)
  124.  
  125. return (x, y)
  126.  
  127. print ("Prepare train data...")
  128. x_train, y_train = load_dataset('data/train')
  129.  
  130. print ("Prepare test data...")
  131. x_test, y_test = load_dataset('data/test')
  132.  
  133. # TEST ANY INPUT IMAGE
  134. # index = 78
  135. # img = Image.fromarray(x_train[index], 'RGB')
  136. # draw = ImageDraw.Draw(img)
  137. # for r in y_train[index]:
  138. # draw.rectangle(r[:-1], outline=(0, 255, 0))
  139. # img.show()
  140. # TEST END
  141.  
  142. input_shape = (input_width, input_height, 3)
  143.  
  144. # Use simple architecture: http://machinethink.net/blog/object-detection-with-yolo/
  145. model = Sequential()
  146. model.add(Convolution2D(16, kernel_size=(3, 3),
  147. activation='relu',
  148. strides=1,
  149. input_shape=input_shape))
  150.  
  151. model.add(MaxPooling2D(pool_size=(2, 2),
  152. strides=2))
  153.  
  154. model.add(Convolution2D(32, kernel_size=(3, 3),
  155. activation='relu',
  156. strides=1))
  157.  
  158. model.add(MaxPooling2D(pool_size=(2, 2),
  159. strides=2))
  160.  
  161. model.add(Convolution2D(64, kernel_size=(3, 3),
  162. activation='relu',
  163. strides=1))
  164.  
  165. model.add(MaxPooling2D(pool_size=(2, 2),
  166. strides=2))
  167.  
  168. model.add(Convolution2D(128, kernel_size=(3, 3),
  169. activation='relu',
  170. strides=1))
  171.  
  172. model.add(MaxPooling2D(pool_size=(2, 2),
  173. strides=2))
  174.  
  175. model.add(Convolution2D(256, kernel_size=(3, 3),
  176. activation='relu',
  177. strides=1))
  178.  
  179. model.add(MaxPooling2D(pool_size=(2, 2),
  180. strides=2))
  181.  
  182. model.add(Convolution2D(512, kernel_size=(3, 3),
  183. activation='relu',
  184. strides=1))
  185.  
  186. model.add(MaxPooling2D(pool_size=(2, 2),
  187. strides=1))
  188.  
  189. model.add(Convolution2D(1024, kernel_size=(3, 3),
  190. activation='relu',
  191. strides=1))
  192.  
  193. model.add(Convolution2D(1024, kernel_size=(3, 3),
  194. activation='relu',
  195. strides=1))
  196.  
  197. model.add(Convolution2D(cell_vec_size, kernel_size=(1, 1),
  198. activation='relu',
  199. strides=1))
  200.  
  201. # model.add(Dense(256))
  202. # model.add(Dense(100))
  203.  
  204.  
  205. model.compile(loss='mse',
  206. optimizer=optimizers.Adadelta(),
  207. metrics=['accuracy'])
  208.  
  209. model.fit(np.array(x_train), np.array(y_train),
  210. batch_size=batch_size,
  211. epochs=epochs,
  212. verbose=1,
  213. validation_data=(np.array(x_test), np.array(y_test)))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement