import numpy as np
import sys
from PIL import Image, ImageDraw, ImageEnhance, ImageColor

class Rpn:
    def __init__(self, tf, resnet_model, image_height, image_width, in_channels, mid_channels, base_anchor_size, anchor_ratios, anchor_scales, subsample_rate):
        """
        in_channels for resnet: 512
        mid_channels: 512
        base_anchor_size: 128
        anchor_ratios: [[1,1],[1,2],[2,1]]
        anchor_scales: [1,2,3,4]
        subsample_rate: image_size // feature_map_size, e.g. 1000 // 32 --> 31
        (see the usage sketch after the class definition)
        """
        self.tf = tf
        self.resnet_model = resnet_model
        self.resnet_model_inputs = self.resnet_model.input
        self.resnet_model_feature_map = resnet_model.output
        self.image_height = image_height
        self.image_width = image_width
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.base_anchor_size = base_anchor_size
        self.anchor_ratios = anchor_ratios
        self.anchor_scales = anchor_scales
        self.subsample_rate = subsample_rate

        # Pre-compute the anchor shapes and tile them over the feature-map grid.
        self.anchor_boxes = self.generate_anchor_boxes(self.base_anchor_size, self.anchor_ratios, self.anchor_scales)
        self.anchor_boxes_over_image = self.generate_anchor_boxes_over_image(self.anchor_boxes, self.image_height, self.image_width, self.subsample_rate)

        self.model = self.create_model()

    def get_model(self):
        return self.model

    def create_model(self):
        # 3x3 "sliding window" convolution over the backbone feature map.
        rpn_conv = self.tf.keras.layers.SeparableConv2D(filters=512, kernel_size=[3, 3], strides=[1, 1], padding="same", data_format="channels_last", activation="relu")(self.resnet_model.output)

        # Objectness head: 2 scores per anchor (e.g. 12 anchors -> 24 channels).
        # The softmax must be taken per anchor over its 2 classes, so it is
        # applied after the reshape rather than over all 24 channels at once.
        rpn_class_score = self.tf.keras.layers.SeparableConv2D(filters=24, kernel_size=[1, 1], strides=[1, 1], padding="valid", data_format="channels_last")(rpn_conv)
        rpn_class_score_shape = rpn_class_score.get_shape().as_list()
        rpn_class_score = self.tf.keras.layers.Reshape(rpn_class_score_shape[1:3] + [rpn_class_score_shape[3] // 2, 2])(rpn_class_score)
        rpn_class_score = self.tf.keras.layers.Softmax(axis=-1)(rpn_class_score)

        # Box-regression head: 4 deltas per anchor (e.g. 12 anchors -> 48 channels).
        rpn_bbox_pred = self.tf.keras.layers.SeparableConv2D(filters=48, kernel_size=[1, 1], strides=[1, 1], padding="valid", data_format="channels_last")(rpn_conv)
        rpn_bbox_pred_shape = rpn_bbox_pred.get_shape().as_list()
        rpn_bbox_pred = self.tf.keras.layers.Reshape(rpn_bbox_pred_shape[1:3] + [rpn_bbox_pred_shape[3] // 4, 4])(rpn_bbox_pred)

        rpn_model = self.tf.keras.Model(inputs=self.resnet_model.input, outputs=[rpn_class_score, rpn_bbox_pred])

        self.optimizer = self.tf.keras.optimizers.Nadam(0.001)
        #rpn_model.compile(optimizer=self.optimizer, loss=[self.tf.keras.losses.BinaryCrossentropy(), self.tf.keras.losses.Huber()], metrics=['binary_accuracy', 'MeanSquaredError'])
        return rpn_model

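    # Worked shape example (illustrative, assuming a 1024x1024 input and a
    # ResNet backbone with stride 32, i.e. a 32x32 feature map; not from the
    # original paste): with 3 ratios x 4 scales = 12 anchors per location,
    #   rpn_class_score: (batch, 32, 32, 12, 2)  -- background/foreground per anchor
    #   rpn_bbox_pred:   (batch, 32, 32, 12, 4)  -- box deltas per anchor
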
    def get_iou(self, anchor_box_predictions, ground_truth_bounding_boxes, giou=False):
        """ Predicted and ground truth bounding box coordinates """
        anchor_box_predictions_center_x, anchor_box_predictions_center_y, anchor_box_predictions_width, anchor_box_predictions_height = anchor_box_predictions[:, :, :, :, 0:1], anchor_box_predictions[:, :, :, :, 1:2], anchor_box_predictions[:, :, :, :, 2:3], anchor_box_predictions[:, :, :, :, 3:4]
        anchor_box_predictions_x1 = anchor_box_predictions_center_x - (anchor_box_predictions_width / 2)
        anchor_box_predictions_x2 = anchor_box_predictions_center_x + (anchor_box_predictions_width / 2)
        anchor_box_predictions_y1 = anchor_box_predictions_center_y - (anchor_box_predictions_height / 2)
        anchor_box_predictions_y2 = anchor_box_predictions_center_y + (anchor_box_predictions_height / 2)

        """ For the predicted boxes ensure x2 > x1 and y2 > y1 """
        anchor_box_predictions_x1, anchor_box_predictions_x2 = np.minimum(anchor_box_predictions_x1, anchor_box_predictions_x2), np.maximum(anchor_box_predictions_x1, anchor_box_predictions_x2)
        anchor_box_predictions_y1, anchor_box_predictions_y2 = np.minimum(anchor_box_predictions_y1, anchor_box_predictions_y2), np.maximum(anchor_box_predictions_y1, anchor_box_predictions_y2)

  65. """ Flatten ground truth boxes sorta """
  66. concat = np.concatenate(ground_truth_bounding_boxes, axis=0)
  67. ground_truth_bounding_boxes_center_x, ground_truth_bounding_boxes_center_y, ground_truth_bounding_boxes_width, ground_truth_bounding_boxes_height = (concat[:, 0:1]).reshape((1,1,1,1,-1)), (concat[:, 1:2]).reshape((1,1,1,1,-1)), (concat[:, 2:3]).reshape((1,1,1,1,-1)), (concat[:, 3:4]).reshape((1,1,1,1,-1))
  68.  
  69. ground_truth_bounding_boxes_x1 = ground_truth_bounding_boxes_center_x - (ground_truth_bounding_boxes_width / 2)
  70. ground_truth_bounding_boxes_x2 = ground_truth_bounding_boxes_center_x + (ground_truth_bounding_boxes_width / 2)
  71. ground_truth_bounding_boxes_y1 = ground_truth_bounding_boxes_center_y - (ground_truth_bounding_boxes_height / 2)
  72. ground_truth_bounding_boxes_y2 = ground_truth_bounding_boxes_center_y + (ground_truth_bounding_boxes_height / 2)
  73.  
  74.  
  75.  
  76. """ Get areas of boxes """
  77. anchor_box_predictions_area = (anchor_box_predictions_x2 - anchor_box_predictions_x1) * (anchor_box_predictions_y2 - anchor_box_predictions_y1)
  78. ground_truth_bounding_boxes_area = (ground_truth_bounding_boxes_x2 - ground_truth_bounding_boxes_x1) * (ground_truth_bounding_boxes_y2 - ground_truth_bounding_boxes_y1)
  79.  
  80. """ Calculate intersection between prediction boxes and ground truth boxes """
  81.  
  82. x1_intersection, x2_intersection = np.maximum(anchor_box_predictions_x1, ground_truth_bounding_boxes_x1), np.minimum(anchor_box_predictions_x2, ground_truth_bounding_boxes_x2)
  83. y1_intersection, y2_intersection = np.maximum(anchor_box_predictions_y1, ground_truth_bounding_boxes_y1), np.minimum(anchor_box_predictions_y2, ground_truth_bounding_boxes_y2)
  84.  
  85.  
  86. """ Calculate intersection area """
  87. intersection = np.where(np.logical_and(x2_intersection > x1_intersection, y2_intersection > y1_intersection), (x2_intersection - x1_intersection) * (y2_intersection - y1_intersection), 0)
  88.  
  89. """ Find the coordinates of smallest enclosing box (union) """
  90. x1_coord, x2_coord = np.minimum(anchor_box_predictions_x1, ground_truth_bounding_boxes_x1), np.maximum(anchor_box_predictions_x2, ground_truth_bounding_boxes_x2)
  91. y1_coord, y2_coord = np.minimum(anchor_box_predictions_y1, ground_truth_bounding_boxes_y1), np.maximum(anchor_box_predictions_y2, ground_truth_bounding_boxes_y2)
  92.  
  93. """ Get area of union box """
  94. union_area = (x2_coord - x1_coord) * (y2_coord - y1_coord)
  95.  
  96. """ Intersection over union """
  97. iou = intersection / (anchor_box_predictions_area + ground_truth_bounding_boxes_area - intersection)
  98. Giou = iou - ((union_area-(anchor_box_predictions_area + ground_truth_bounding_boxes_area - intersection))/union_area)

        """ Per-image index ranges into the concatenated ground truth array """
        indices = [len(bb) for bb in ground_truth_bounding_boxes]
        indices_end = [np.sum(indices[:i + 1]) for i in range(len(indices))]
        indices_start = np.insert(indices_end, 0, 0)[:-1]

        """ For each batch image, keep the best IoU/GIoU over that image's own ground truth boxes, plus the matching ground truth box for every anchor """
        iou_true = np.zeros(iou.shape[:-1])
        Giou_true = np.zeros(Giou.shape[:-1])
        bounding_boxes_true = np.zeros(anchor_box_predictions.shape)
        for index, (s, e) in enumerate(zip(indices_start, indices_end)):
            iou_true[index] = np.amax(iou[index, :, :, :, s:e], axis=3)
            Giou_true[index] = np.amax(Giou[index, :, :, :, s:e], axis=3)
            if giou is True:
                bounding_boxes_true[index] = (ground_truth_bounding_boxes[index][np.argmax(Giou[index, :, :, :, s:e], axis=3)])[:, :, :, :4]
            else:
                bounding_boxes_true[index] = (ground_truth_bounding_boxes[index][np.argmax(iou[index, :, :, :, s:e], axis=3)])[:, :, :, :4]

        return (Giou_true, bounding_boxes_true) if giou else (iou_true, bounding_boxes_true)
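
    # Worked IoU/GIoU example (illustrative numbers, not from the original paste):
    #   prediction:   center (50, 50), size 40x40 -> corners (30, 30, 70, 70), area 1600
    #   ground truth: center (70, 70), size 40x40 -> corners (50, 50, 90, 90), area 1600
    #   intersection = 20 * 20 = 400, union = 1600 + 1600 - 400 = 2800, IoU = 400 / 2800 ~= 0.143
    #   enclosing box = (30, 30, 90, 90), area 3600, GIoU = 0.143 - (3600 - 2800) / 3600 ~= -0.079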

    def train_model(self, inputs, ground_truth_bounding_boxes):
        batch_size = inputs.shape[0]
        predictions = self.model.predict(inputs)
        object_predictions = predictions[0]
        anchor_box_change_predictions = predictions[1]
        # Apply the predicted deltas to the fixed anchor grid, then label anchors by GIoU with the ground truth.
        anchor_box_predictions = self.anchor_boxes_over_image + anchor_box_change_predictions
        anchor_boxes_giou, anchor_boxes_nearest_bounding_box = self.get_iou(anchor_box_predictions, ground_truth_bounding_boxes, giou=True)
        object_prediction_labels = np.zeros(object_predictions.shape[:-1])
        object_prediction_labels[anchor_boxes_giou > 0.5] += 1
        # Debugging: inspect shapes and stop here.
        print(object_predictions.shape, anchor_box_change_predictions.shape, object_prediction_labels.shape)
        print([x.shape for x in self.model.trainable_weights])
        sys.exit()

        mini_batch_sample_size = 256 * batch_size
        # Take the foreground probability (here assumed to be channel 1) so the flattened predictions line up with the per-anchor labels.
        object_predictions_flat = object_predictions[..., 1].flatten()
        object_prediction_labels_flat = object_prediction_labels.flatten()
        anchor_box_predictions_flat = anchor_box_predictions.reshape((-1, 4))
        anchor_boxes_nearest_bounding_box_flat = anchor_boxes_nearest_bounding_box.reshape((-1, 4))


        """ Shuffle flattened y_pred and y_true all in the same order """
        random_indices = np.random.choice(object_predictions_flat.shape[0], object_predictions_flat.shape[0], replace=False)
        object_predictions_flat = object_predictions_flat[random_indices]
        object_prediction_labels_flat = object_prediction_labels_flat[random_indices]
        anchor_box_predictions_flat = anchor_box_predictions_flat[random_indices]
        anchor_boxes_nearest_bounding_box_flat = anchor_boxes_nearest_bounding_box_flat[random_indices]

        """ Sort in ascending order from background to foreground """
        ind_sorted = np.argsort(object_prediction_labels_flat)
        object_predictions_flat = object_predictions_flat[ind_sorted]
        object_prediction_labels_flat = object_prediction_labels_flat[ind_sorted]
        anchor_box_predictions_flat = anchor_box_predictions_flat[ind_sorted]
        anchor_boxes_nearest_bounding_box_flat = anchor_boxes_nearest_bounding_box_flat[ind_sorted]

  156. """ Get 128 background anchors and 128 foreground anchors and merge them into a single batch """
  157. split_amount = mini_batch_sample_size // 2
  158. #Background
  159. background_object_predictions_flat = object_predictions_flat[:split_amount]
  160. background_object_prediction_labels_flat = object_prediction_labels_flat[:split_amount]
  161.  
  162. #Foreground
  163. foreground_object_predictions_flat = object_predictions_flat[-split_amount:]
  164. foreground_object_prediction_labels_flat = object_prediction_labels_flat[-split_amount:]
  165. foreground_anchor_box_predictions_flat = anchor_box_predictions_flat[-split_amount:]
  166. foreground_anchor_boxes_nearest_bounding_box_flat = anchor_boxes_nearest_bounding_box_flat[-split_amount:]
  167.  
  168.  
  169. #Merge
  170. final_object_predictions_flat = np.concatenate((background_object_predictions_flat,foreground_object_predictions_flat))
  171. final_object_prediction_labels_flat = np.concatenate((background_object_prediction_labels_flat, foreground_object_prediction_labels_flat))
  172. #Take only foreground anchor and the closest ground truth boxes from the mini batch
  173. final_anchor_box_predictions_flat = foreground_anchor_box_predictions_flat[np.nonzero(foreground_object_prediction_labels_flat)]
  174. final_anchor_boxes_nearest_bounding_box_flat = foreground_anchor_boxes_nearest_bounding_box_flat[np.nonzero(foreground_object_prediction_labels_flat)]
  175.  
  176.  
  177.  
        # ==================================================== WHAT DO I DO FROM HERE????????????????? ====================================================
        # (one possible way to finish this step is sketched in _train_step_sketch below)

        class_binary_loss = self.tf.keras.losses.BinaryCrossentropy()
        huber = self.tf.keras.losses.Huber(reduction=self.tf.keras.losses.Reduction.NONE)
        anchor_box_regression_loss = huber(final_anchor_boxes_nearest_bounding_box_flat, final_anchor_box_predictions_flat)
        print(anchor_box_regression_loss.shape)

        #print((tf.keras.optimizers.get_weights()).shape)
        #print([x.shape for x in self.model.trainable_weights])
        lo = class_binary_loss(final_object_prediction_labels_flat, final_object_predictions_flat)
        #print(lo.shape)
        g = self.tf.keras.backend.gradients(self.tf.ones([1], self.tf.int32), self.model.trainable_weights[:-3])
        #self.optimizer.minimize(lo, self.model.trainable_weights[:-3])
        #self.optimizer.minimize(anchor_box_regression_loss, self.model.trainable_weights[:-6]+self.model.trainable_weights[-3:])
        #print(class_binary_loss, anchor_box_regression_loss)

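    # --- Hedged sketch, not part of the original paste -----------------------
    # One common way to answer the "what do I do from here" question above:
    # run the model inside a tf.GradientTape, compute the classification and
    # regression losses on the sampled mini-batch, and apply the summed
    # gradients. The method name and its arguments (per-anchor labels,
    # regression targets and a foreground mask) are illustrative assumptions,
    # not the original author's API.
    def _train_step_sketch(self, inputs, object_labels, bbox_targets, foreground_mask):
        bce = self.tf.keras.losses.BinaryCrossentropy()
        huber = self.tf.keras.losses.Huber()
        with self.tf.GradientTape() as tape:
            class_scores, bbox_deltas = self.model(inputs, training=True)
            # Foreground probability is assumed to be channel 1 of the 2-way softmax.
            class_loss = bce(object_labels, class_scores[..., 1])
            # Regression loss only over anchors marked as foreground.
            regression_loss = huber(bbox_targets * foreground_mask, bbox_deltas * foreground_mask)
            total_loss = class_loss + regression_loss
        gradients = tape.gradient(total_loss, self.model.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_weights))
        return total_loss
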
    def generate_anchor_boxes(self, base_anchor_size, anchor_ratios, anchor_scales):
        """
        Every anchor box shape is different.
        Number of anchor boxes: len(anchor_ratios) * len(anchor_scales)
        """
        anchor_boxes = []
        for scale in anchor_scales:
            for aspect_ratio in anchor_ratios:
                height = base_anchor_size * scale * aspect_ratio[0]
                width = base_anchor_size * scale * aspect_ratio[1]
                anchor_box = [width, height]
                anchor_boxes.append(anchor_box)

        return anchor_boxes

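    # Worked example (using the values suggested in the __init__ docstring; illustrative only):
    # base_anchor_size=128, anchor_ratios=[[1,1],[1,2],[2,1]], anchor_scales=[1,2,3,4]
    # gives 3 * 4 = 12 anchors per location; e.g. scale 2 with ratio [1,2]
    # yields height = 128*2*1 = 256 and width = 128*2*2 = 512, i.e. [512, 256].
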
    def generate_anchor_boxes_over_image(self, anchor_boxes, image_height, image_width, subsample_rate):
        """
        Returns an array of all the anchor boxes, each with shape [anchor_box_center_x, anchor_box_center_y, width, height]
        """
        anchor_boxes_over_image = np.zeros((image_height // subsample_rate, image_width // subsample_rate, len(anchor_boxes), 4))
        for row in range(anchor_boxes_over_image.shape[0]):
            for col in range(anchor_boxes_over_image.shape[1]):
                for anchor_idx in range(anchor_boxes_over_image.shape[2]):
                    # col indexes the horizontal (x) axis and row the vertical (y) axis, matching the [center_x, center_y, width, height] layout above.
                    anchor_boxes_over_image[row, col, anchor_idx] = [(col + 1) * subsample_rate, (row + 1) * subsample_rate, anchor_boxes[anchor_idx][0], anchor_boxes[anchor_idx][1]]
        return anchor_boxes_over_image
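
# --------------------------------------------------------------------------
# Hedged usage sketch, not part of the original paste: one way the class
# above could be wired to a Keras ResNet backbone, using the example values
# from the __init__ docstring. The backbone choice, input size and the dummy
# data are illustrative assumptions only. Note that train_model, as pasted,
# stops at a debugging sys.exit() after printing shapes.
if __name__ == "__main__":
    import tensorflow as tf

    # Stride-32 backbone: a 1024x1024 input gives a 32x32 feature map.
    backbone = tf.keras.applications.ResNet50(include_top=False, weights=None, input_shape=(1024, 1024, 3))
    rpn = Rpn(tf, backbone,
              image_height=1024, image_width=1024,
              in_channels=2048, mid_channels=512,
              base_anchor_size=128,
              anchor_ratios=[[1, 1], [1, 2], [2, 1]],
              anchor_scales=[1, 2, 3, 4],
              subsample_rate=32)

    # One dummy image and one ground-truth box per image, given as [center_x, center_y, width, height].
    images = np.zeros((1, 1024, 1024, 3), dtype=np.float32)
    ground_truth = [np.array([[500.0, 500.0, 100.0, 200.0]])]
    rpn.train_model(images, ground_truth)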