Advertisement
Guest User

Untitled

a guest
Apr 9th, 2020
262
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.55 KB | None | 0 0
  1. def yolo_loss(args,
  2.               anchors,
  3.               num_classes,
  4.               rescore_confidence=False,
  5.               print_loss=False):
  6.     """YOLO localization loss function.
  7.  
  8.    Parameters
  9.    ----------
  10.    yolo_output : tensor
  11.        Final convolutional layer features.
  12.  
  13.    true_boxes : tensor
  14.        Ground truth boxes tensor with shape [batch, num_true_boxes, 5]
  15.        containing box x_center, y_center, width, height, and class.
  16.  
  17.    detectors_mask : array
  18.        0/1 mask for detector positions where there is a matching ground truth.
  19.  
  20.    matching_true_boxes : array
  21.        Corresponding ground truth boxes for positive detector positions.
  22.        Already adjusted for conv height and width.
  23.  
  24.    anchors : tensor
  25.        Anchor boxes for model.
  26.  
  27.    num_classes : int
  28.        Number of object classes.
  29.  
  30.    rescore_confidence : bool, default=False
  31.        If true then set confidence target to IOU of best predicted box with
  32.        the closest matching ground truth box.
  33.  
  34.    print_loss : bool, default=False
  35.        If True then use a tf.Print() to print the loss components.
  36.  
  37.    Returns
  38.    -------
  39.    mean_loss : float
  40.        mean localization loss across minibatch
  41.    """
  42.     (yolo_output, true_boxes, detectors_mask, matching_true_boxes) = args
  43.     num_anchors = len(anchors)
  44.     object_scale = 5
  45.     no_object_scale = 1
  46.     class_scale = 1
  47.     coordinates_scale = 1
  48.     pred_xy, pred_wh, pred_confidence, pred_class_prob = yolo_head(
  49.         yolo_output, anchors, num_classes)
  50.  
  51.     # Unadjusted box predictions for loss.
  52.     # TODO: Remove extra computation shared with yolo_head.
  53.     yolo_output_shape = K.shape(yolo_output)
  54.     feats = K.reshape(yolo_output, [
  55.         -1, yolo_output_shape[1], yolo_output_shape[2], num_anchors,
  56.         num_classes + 5
  57.     ])
  58.     pred_boxes = K.concatenate(
  59.         (K.sigmoid(feats[..., 0:2]), feats[..., 2:4]), axis=-1)
  60.  
  61.     # TODO: Adjust predictions by image width/height for non-square images?
  62.     # IOUs may be off due to different aspect ratio.
  63.  
  64.     # Expand pred x,y,w,h to allow comparison with ground truth.
  65.     # batch, conv_height, conv_width, num_anchors, num_true_boxes, box_params
  66.     pred_xy = K.expand_dims(pred_xy, 4)
  67.     pred_wh = K.expand_dims(pred_wh, 4)
  68.  
  69.     pred_wh_half = pred_wh / 2.
  70.     pred_mins = pred_xy - pred_wh_half
  71.     pred_maxes = pred_xy + pred_wh_half
  72.  
  73.     true_boxes_shape = K.shape(true_boxes)
  74.  
  75.     # batch, conv_height, conv_width, num_anchors, num_true_boxes, box_params
  76.     true_boxes = K.reshape(true_boxes, [
  77.         true_boxes_shape[0], 1, 1, 1, true_boxes_shape[1], true_boxes_shape[2]
  78.     ])
  79.     true_xy = true_boxes[..., 0:2]
  80.     true_wh = true_boxes[..., 2:4]
  81.  
  82.     # Find IOU of each predicted box with each ground truth box.
  83.     true_wh_half = true_wh / 2.
  84.     true_mins = true_xy - true_wh_half
  85.     true_maxes = true_xy + true_wh_half
  86.  
  87.     intersect_mins = K.maximum(pred_mins, true_mins)
  88.     intersect_maxes = K.minimum(pred_maxes, true_maxes)
  89.     intersect_wh = K.maximum(intersect_maxes - intersect_mins, 0.)
  90.     intersect_areas = intersect_wh[..., 0] * intersect_wh[..., 1]
  91.  
  92.     pred_areas = pred_wh[..., 0] * pred_wh[..., 1]
  93.     true_areas = true_wh[..., 0] * true_wh[..., 1]
  94.  
  95.     union_areas = pred_areas + true_areas - intersect_areas
  96.     iou_scores = intersect_areas / union_areas
  97.  
  98.     # Best IOUs for each location.
  99.     best_ious = K.max(iou_scores, axis=4)  # Best IOU scores.
  100.     best_ious = K.expand_dims(best_ious)
  101.  
  102.     # A detector has found an object if IOU > thresh for some true box.
  103.     object_detections = K.cast(best_ious > 0.6, K.dtype(best_ious))
  104.  
  105.     # TODO: Darknet region training includes extra coordinate loss for early
  106.     # training steps to encourage predictions to match anchor priors.
  107.  
  108.     # Determine confidence weights from object and no_object weights.
  109.     # NOTE: YOLO does not use binary cross-entropy here.
  110.     no_object_weights = (no_object_scale * (1 - object_detections) *
  111.                          (1 - detectors_mask))
  112.     no_objects_loss = no_object_weights * K.square(-pred_confidence)
  113.  
  114.     if rescore_confidence:
  115.         objects_loss = (object_scale * detectors_mask *
  116.                         K.square(best_ious - pred_confidence))
  117.     else:
  118.         objects_loss = (object_scale * detectors_mask *
  119.                         K.square(1 - pred_confidence))
  120.     confidence_loss = objects_loss + no_objects_loss
  121.  
  122.     # Classification loss for matching detections.
  123.     # NOTE: YOLO does not use categorical cross-entropy loss here.
  124.     matching_classes = K.cast(matching_true_boxes[..., 4], 'int32')
  125.     matching_classes = K.one_hot(matching_classes, num_classes)
  126.     classification_loss = (class_scale * detectors_mask *
  127.                            K.square(matching_classes - pred_class_prob))
  128.  
  129.     # Coordinate loss for matching detection boxes.
  130.     matching_boxes = matching_true_boxes[..., 0:4]
  131.     coordinates_loss = (coordinates_scale * detectors_mask *
  132.                         K.square(matching_boxes - pred_boxes))
  133.  
  134.     confidence_loss_sum = K.sum(confidence_loss)
  135.     classification_loss_sum = K.sum(classification_loss)
  136.     coordinates_loss_sum = K.sum(coordinates_loss)
  137.     total_loss = 0.5 * (
  138.         confidence_loss_sum + classification_loss_sum + coordinates_loss_sum)
  139.     if print_loss:
  140.         total_loss = tf.Print(
  141.             total_loss, [
  142.                 total_loss, confidence_loss_sum, classification_loss_sum,
  143.                 coordinates_loss_sum
  144.             ],
  145.             message='yolo_loss, conf_loss, class_loss, box_coord_loss:')
  146.  
  147.     return total_loss
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement