Advertisement
Guest User

Untitled

a guest
Oct 13th, 2019
130
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.89 KB | None | 0 0
  1. """
  2. Ensembling methods for object detection.
  3. """
  4.  
  5. """
  6. General Ensemble - find overlapping boxes of the same class and average their positions
  7. while adding their confidences. Can weigh different detectors with different weights.
  8. No real learning here, although the weights and iou_thresh can be optimized.
  9.  
  10. Input:
  11. - dets : List of detections. Each detection is all the output from one detector, and
  12. should be a list of boxes, where each box should be on the format
  13. [box_x, box_y, box_w, box_h, class, confidence] where box_x and box_y
  14. are the center coordinates, box_w and box_h are width and height resp.
  15. The values should be floats, except the class which should be an integer.
  16.  
  17. - iou_thresh: Threshold in terms of IOU where two boxes are considered the same,
  18. if they also belong to the same class.
  19.  
  20. - weights: A list of weights, describing how much more some detectors should
  21. be trusted compared to others. The list should be as long as the
  22. number of detections. If this is set to None, then all detectors
  23. will be considered equally reliable. The sum of weights does not
  24. necessarily have to be 1.
  25.  
  26. Output:
  27. A list of boxes, on the same format as the input. Confidences are in range 0-1.
  28. """
  29. def GeneralEnsemble(dets, iou_thresh = 0.5, weights=None):
  30. assert(type(iou_thresh) == float)
  31.  
  32. ndets = len(dets)
  33.  
  34. if weights is None:
  35. w = 1/float(ndets)
  36. weights = [w]*ndets
  37. else:
  38. assert(len(weights) == ndets)
  39.  
  40. s = sum(weights)
  41. for i in range(0, len(weights)):
  42. weights[i] /= s
  43.  
  44. out = list()
  45. used = list()
  46.  
  47. for idet in range(0,ndets):
  48. det = dets[idet]
  49. for box in det:
  50. if box in used:
  51. continue
  52.  
  53. used.append(box)
  54. # Search the other detectors for overlapping box of same class
  55. found = []
  56. for iodet in range(0, ndets):
  57. odet = dets[iodet]
  58.  
  59. if odet == det:
  60. continue
  61.  
  62. bestbox = None
  63. bestiou = iou_thresh
  64. for obox in odet:
  65. if not obox in used:
  66. # Not already used
  67. if box[4] == obox[4]:
  68. # Same class
  69. iou = computeIOU(box, obox)
  70. if iou > bestiou:
  71. bestiou = iou
  72. bestbox = obox
  73.  
  74. if not bestbox is None:
  75. w = weights[iodet]
  76. found.append((bestbox,w))
  77. used.append(bestbox)
  78.  
  79. # Now we've gone through all other detectors
  80. if len(found) == 0:
  81. new_box = list(box)
  82. new_box[5] /= ndets
  83. out.append(new_box)
  84. else:
  85. allboxes = [(box, weights[idet])]
  86. allboxes.extend(found)
  87.  
  88. xc = 0.0
  89. yc = 0.0
  90. bw = 0.0
  91. bh = 0.0
  92. conf = 0.0
  93.  
  94. wsum = 0.0
  95. for bb in allboxes:
  96. w = bb[1]
  97. wsum += w
  98.  
  99. b = bb[0]
  100. xc += w*b[0]
  101. yc += w*b[1]
  102. bw += w*b[2]
  103. bh += w*b[3]
  104. conf += w*b[5]
  105.  
  106. xc /= wsum
  107. yc /= wsum
  108. bw /= wsum
  109. bh /= wsum
  110.  
  111. new_box = [xc, yc, bw, bh, box[4], conf]
  112. out.append(new_box)
  113. return out
  114.  
  115. def getCoords(box):
  116. x1 = float(box[0]) - float(box[2])/2
  117. x2 = float(box[0]) + float(box[2])/2
  118. y1 = float(box[1]) - float(box[3])/2
  119. y2 = float(box[1]) + float(box[3])/2
  120. return x1, x2, y1, y2
  121.  
  122. def computeIOU(box1, box2):
  123. x11, x12, y11, y12 = getCoords(box1)
  124. x21, x22, y21, y22 = getCoords(box2)
  125.  
  126. x_left = max(x11, x21)
  127. y_top = max(y11, y21)
  128. x_right = min(x12, x22)
  129. y_bottom = min(y12, y22)
  130.  
  131. if x_right < x_left or y_bottom < y_top:
  132. return 0.0
  133.  
  134. intersect_area = (x_right - x_left) * (y_bottom - y_top)
  135. box1_area = (x12 - x11) * (y12 - y11)
  136. box2_area = (x22 - x21) * (y22 - y21)
  137.  
  138. iou = intersect_area / (box1_area + box2_area - intersect_area + 1e-6)
  139. return iou
  140.  
  141. if __name__=="__main__":
  142. # Toy example
  143. dets = [
  144. [[0.1, 0.1, 1.0, 1.0, 1, 0.9], [1.2, 1.4, 0.5, 1.5, 1, 0.9]],
  145. [[0.2, 0.1, 0.9, 1.1, 1, 0.8],[1.19, 1.4, 0.5, 1.5, 1, 0.9]],
  146. [[5.0,5.0,1.0,1.0,1,0.5]]
  147. ]
  148.  
  149. ens = GeneralEnsemble(dets, weights = [1.0, 0.1, 0.5])
  150. print(ens)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement